Compare commits
5 Commits
be481eaa37
...
0842b4d0bc
| Author | SHA1 | Date | |
|---|---|---|---|
| 0842b4d0bc | |||
| d1f845b068 | |||
| e6db6bcc68 | |||
| 86fdef4b38 | |||
| 4269799f8f |
1719
Cargo.lock
generated
1719
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -6,9 +6,14 @@ description = "Core traits and types for building API clients with http_derive"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.100"
|
||||
async-trait = "0.1"
|
||||
awc = "3"
|
||||
reqwest = { version = "0.12.23" , features = ["json", "rustls-tls"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1.0.145"
|
||||
toml = "0.9.5"
|
||||
urlencoding = "2"
|
||||
tokio = { version = "1.47", features = ["full"] }
|
||||
once_cell = "1.19"
|
||||
num_cpus = "1.16"
|
||||
|
||||
|
||||
449
src/lib.rs
449
src/lib.rs
@ -1,64 +1,24 @@
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{ Deserialize, Serialize };
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::fs;
|
||||
use once_cell::sync::Lazy;
|
||||
use num_cpus;
|
||||
use reqwest::{Client, Method, header::HeaderMap };
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, fs, path::Path, sync::Arc};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
/// A trait every API request type should implement (via macro).
|
||||
/// Provides info about the HTTP protocol, URLs, and supported methods.
|
||||
pub trait HasHttp {
|
||||
/// Methods this request supports (GET, POST, etc.)
|
||||
fn http_methods() -> &'static [&'static str];
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// GLOBAL CONCURRENCY CONTROL
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
/// Production endpoint (base URL).
|
||||
fn live_url() -> &'static str;
|
||||
/// Shared semaphore to limit concurrent requests (2 × CPU threads).
|
||||
static SEMAPHORE: Lazy<Arc<Semaphore>> =
|
||||
Lazy::new(|| Arc::new(Semaphore::new(num_cpus::get() * 2)));
|
||||
|
||||
/// Sandbox endpoint (base URL).
|
||||
fn sandbox_url() -> &'static str;
|
||||
}
|
||||
|
||||
/// Trait implemented by all request types that can be sent to an HTTP API.
|
||||
/// Usually derived via `#[derive(HttpRequest)]`.
|
||||
#[async_trait]
|
||||
pub trait Queryable: HasHttp + Send + Sync {
|
||||
/// The response type for this query.
|
||||
type R;
|
||||
/// The error type for this query.
|
||||
type E: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Send the query, with flexibility for environment and overrides.
|
||||
///
|
||||
/// - `override_url`: if `Some`, use this instead of live/sandbox.
|
||||
/// - `sandbox`: if true, use sandbox_url.
|
||||
/// - `method_override`: optional HTTP method override (GET, POST, etc.).
|
||||
/// - `headers`: optional headers.
|
||||
async fn send(
|
||||
&self,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::R, Self::E>;
|
||||
}
|
||||
|
||||
/// Trait for dispatching API calls in bulk or single-shot mode.
|
||||
///
|
||||
/// Implemented automatically for generated enums by the `#[api_dispatch]` macro.
|
||||
#[async_trait]
|
||||
pub trait ApiDispatch {
|
||||
/// Send all queries represented by this enum variant.
|
||||
///
|
||||
/// - `client`: A preconfigured `awc::Client`
|
||||
/// - `keys`: A shared API key store (depends on your app crate)
|
||||
async fn send_all(
|
||||
&self,
|
||||
client: &awc::Client,
|
||||
keys: &HashMap<String, crate::Keys>, // ⚠️ placeholder, you may need to re-export Keys
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||
}
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// API KEY MANAGEMENT
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Key {
|
||||
@ -72,17 +32,16 @@ pub struct Keys {
|
||||
}
|
||||
|
||||
impl Keys {
|
||||
/// Convenience method to get a key by name (e.g. "alpaca")
|
||||
pub fn get(&self, name: &str) -> Option<&Key> {
|
||||
self.keys.get(name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads API keys from ~/.config/keys or fallback ./keys directory.
|
||||
/// Each `.toml` file should contain a `Key { key, secret }` struct.
|
||||
/// The filename (without `.toml`) becomes the key name in the map.
|
||||
pub fn load_api_keys() -> Result<Keys, Box<dyn std::error::Error>> {
|
||||
let home_dir = std::env::var("HOME")?;
|
||||
/// Load API keys from either:
|
||||
/// - `$HOME/.config/keys`
|
||||
/// - or local `./keys` directory
|
||||
pub fn load_api_keys() -> Result<Keys> {
|
||||
let home_dir = std::env::var("HOME").context("Failed to read $HOME")?;
|
||||
let config_dir = Path::new(&home_dir).join(".config/keys");
|
||||
let fallback_dir = Path::new("keys");
|
||||
|
||||
@ -94,14 +53,17 @@ pub fn load_api_keys() -> Result<Keys, Box<dyn std::error::Error>> {
|
||||
|
||||
let mut keys_map = HashMap::new();
|
||||
|
||||
for entry in fs::read_dir(dir_path)? {
|
||||
let entry = entry?;
|
||||
for entry in fs::read_dir(&dir_path)
|
||||
.with_context(|| format!("Failed to read keys directory: {}", dir_path.display()))?
|
||||
{
|
||||
let entry = entry.context("Failed to read directory entry")?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("toml") {
|
||||
let contents = fs::read_to_string(&path)?;
|
||||
let key: Key = toml::from_str(&contents)?;
|
||||
|
||||
let contents = fs::read_to_string(&path)
|
||||
.with_context(|| format!("Failed to read file: {}", path.display()))?;
|
||||
let key: Key = toml::from_str(&contents)
|
||||
.with_context(|| format!("Failed to parse TOML: {}", path.display()))?;
|
||||
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
keys_map.insert(stem.to_string(), key);
|
||||
}
|
||||
@ -111,4 +73,355 @@ pub fn load_api_keys() -> Result<Keys, Box<dyn std::error::Error>> {
|
||||
Ok(Keys { keys: keys_map })
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// CORE HTTP REQUEST HELPER
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
/// Shared helper to send HTTP requests with concurrency limiting, API key injection,
|
||||
/// and optional custom headers.
|
||||
pub async fn send_request<T, R>(
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
api_name: &str,
|
||||
base_url: &str,
|
||||
path_template: &str,
|
||||
override_url: Option<&str>,
|
||||
method: &str,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &T,
|
||||
) -> Result<R>
|
||||
where
|
||||
T: Serialize + Send + Sync,
|
||||
R: DeserializeOwned + Send + 'static,
|
||||
{
|
||||
// Acquire semaphore permit (limits concurrency)
|
||||
let _permit = SEMAPHORE.acquire().await.expect("Semaphore poisoned");
|
||||
|
||||
// Build initial URL: override or base + path_template
|
||||
let mut url = if let Some(o) = override_url {
|
||||
o.to_string()
|
||||
} else {
|
||||
if base_url.ends_with('/') {
|
||||
format!("{}{}", base_url.trim_end_matches('/'), path_template)
|
||||
} else {
|
||||
format!(
|
||||
"{}/{}",
|
||||
base_url.trim_end_matches('/'),
|
||||
path_template.trim_start_matches('/')
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Apply path + query replacements (safe no-ops if none exist)
|
||||
url = replace_path_params(&url, &[], body)
|
||||
.context("Failed to replace path params")?;
|
||||
url = append_query_params(&url, &[], body)
|
||||
.context("Failed to append query params")?;
|
||||
|
||||
// Load API keys (pre-provided Arc<Keys>)
|
||||
let (api_key, secret_key) = keys
|
||||
.get(api_name)
|
||||
.map(|k| (k.key.clone(), k.secret.clone()))
|
||||
.unwrap_or_else(|| (String::new(), String::new()));
|
||||
|
||||
// Parse HTTP method into reqwest::Method
|
||||
let method_obj =
|
||||
Method::from_bytes(method.as_bytes()).context("Invalid HTTP method string")?;
|
||||
|
||||
// Build request
|
||||
let mut req = client.request(method_obj.clone(), &url);
|
||||
|
||||
// Attach API credentials if present
|
||||
if !api_key.is_empty() {
|
||||
req = req
|
||||
.header("APCA-API-KEY-ID", api_key)
|
||||
.header("APCA-API-SECRET-KEY", secret_key);
|
||||
}
|
||||
|
||||
// ✅ Add custom headers if provided
|
||||
if let Some(extra_headers) = headers {
|
||||
for (k, v) in extra_headers {
|
||||
req = req.header(k, v);
|
||||
}
|
||||
}
|
||||
|
||||
// Attach JSON body only for non-GET/HEAD methods
|
||||
if !(method_obj == Method::GET || method_obj == Method::HEAD) {
|
||||
req = req.json(body);
|
||||
}
|
||||
|
||||
// Send
|
||||
let resp = req.send().await.context("HTTP send failed")?;
|
||||
let status = resp.status();
|
||||
|
||||
if status.is_success() {
|
||||
let parsed: R = resp.json().await.context("Failed to parse JSON response")?;
|
||||
Ok(parsed)
|
||||
} else {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow!(
|
||||
"[{}] HTTP {} failed: {}",
|
||||
api_name,
|
||||
status,
|
||||
text
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// TRAITS
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
pub trait HasHttp {
|
||||
fn http_methods() -> &'static [&'static str];
|
||||
fn live_url() -> &'static str;
|
||||
fn sandbox_url() -> &'static str;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Queryable: Send + Sync + 'static {
|
||||
type Response: Send + Sync + 'static;
|
||||
|
||||
/// Primary method: send the request using a provided `keys` set.
|
||||
async fn send_with_keys(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::Response>;
|
||||
|
||||
/// Convenience wrapper: loads keys automatically from disk.
|
||||
async fn send(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::Response> {
|
||||
let keys = load_api_keys().map_err(|e| anyhow!("Failed to load API keys: {}", e))?;
|
||||
self.send_with_keys(client, Arc::new(keys), override_url, sandbox, method_override, headers)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ApiDispatch: Send + Sync + 'static {
|
||||
type Request;
|
||||
type Response;
|
||||
type Error;
|
||||
|
||||
async fn send_all(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<Vec<anyhow::Result<Value>>>;
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// API ENDPOINT STRUCTURE + HELPERS
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiEndpoint {
|
||||
pub name: String,
|
||||
pub operation_id: Option<String>,
|
||||
pub method: String,
|
||||
pub path_template: String,
|
||||
pub path_params: Vec<String>,
|
||||
pub query_params: Vec<String>,
|
||||
pub live_url: String,
|
||||
pub sandbox_url: String,
|
||||
pub request_type: Option<String>,
|
||||
pub responses: Vec<(String, String)>,
|
||||
pub requires_auth: bool,
|
||||
}
|
||||
|
||||
pub fn replace_path_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
|
||||
let mut url = url.to_string();
|
||||
let val = serde_json::to_value(req)?;
|
||||
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in path param serialization"))?;
|
||||
|
||||
for p in params {
|
||||
if let Some(v) = obj.get(p) {
|
||||
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
|
||||
url = url.replace(&format!("{{{}}}", p), &s);
|
||||
}
|
||||
}
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
pub fn append_query_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
|
||||
if params.is_empty() {
|
||||
return Ok(url.to_string());
|
||||
}
|
||||
|
||||
let val = serde_json::to_value(req)?;
|
||||
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in query param serialization"))?;
|
||||
|
||||
let mut query_pairs = Vec::new();
|
||||
for p in params {
|
||||
if let Some(v) = obj.get(p) {
|
||||
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
|
||||
query_pairs.push(format!("{}={}", p, urlencoding::encode(&s)));
|
||||
}
|
||||
}
|
||||
|
||||
if query_pairs.is_empty() {
|
||||
Ok(url.to_string())
|
||||
} else if url.contains('?') {
|
||||
Ok(format!("{}&{}", url, query_pairs.join("&")))
|
||||
} else {
|
||||
Ok(format!("{}?{}", url, query_pairs.join("&")))
|
||||
}
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// BLANKET ADAPTER — make any Queryable act as an ApiDispatch
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
#[async_trait]
|
||||
impl<T> ApiDispatch for T
|
||||
where
|
||||
T: Queryable + Serialize + Send + Sync,
|
||||
T::Response: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
type Request = ();
|
||||
type Response = ();
|
||||
type Error = anyhow::Error;
|
||||
|
||||
async fn send_all(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<Vec<anyhow::Result<Value>>> {
|
||||
let transport_res = self
|
||||
.send_with_keys(client, keys, override_url, sandbox, method_override, None)
|
||||
.await;
|
||||
|
||||
let res: anyhow::Result<Value> = match transport_res {
|
||||
Ok(v) => Ok(serde_json::to_value(v).unwrap_or(Value::Null)),
|
||||
Err(e) => Err(e),
|
||||
};
|
||||
|
||||
Ok(vec![res])
|
||||
}
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// UNIVERSAL HTTP HELPER (for generated API structs)
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn http_helper<T>(
|
||||
req: &T,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<String>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<T::Response>
|
||||
where
|
||||
T: Queryable + Serialize + Sync,
|
||||
T::Response: Send + Sync + 'static,
|
||||
{
|
||||
let override_str = override_url.as_deref();
|
||||
req.send_with_keys(client, keys, override_str, sandbox, method_override, None)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Execute a list of ApiDispatch objects concurrently using a shared Client and Keys.
|
||||
/// - `dispatches`: vector of dispatchable objects (Arc<dyn ApiDispatch>).
|
||||
/// - `client`: shared reqwest client (Arc) — cloned for each task cheaply.
|
||||
/// - `keys`: shared keys (Arc).
|
||||
/// - `override_url` and `sandbox` forwarded to each dispatch call.
|
||||
/// - `concurrency_limit`: optional cap; if None, uses SEMAPHORE global limit.
|
||||
///
|
||||
/// Returns flattened Vec<anyhow::Result<Value>> (each dispatch may yield multiple Values).
|
||||
/// Execute a list of ApiDispatch objects concurrently using a shared Client and Keys.
|
||||
pub async fn execute_dispatches_concurrent(
|
||||
dispatches: Vec<
|
||||
Arc<
|
||||
dyn ApiDispatch<
|
||||
Request = (),
|
||||
Response = serde_json::Value,
|
||||
Error = anyhow::Error,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
concurrency_limit: Option<usize>,
|
||||
) -> anyhow::Result<Vec<anyhow::Result<serde_json::Value>>> {
|
||||
use tokio::task::JoinSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Optional per-call semaphore (create local one if a specific limit is provided)
|
||||
let local_sem = concurrency_limit.map(|n| Arc::new(Semaphore::new(n)));
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
for dispatch in dispatches.into_iter() {
|
||||
let client = client.clone();
|
||||
let keys = keys.clone();
|
||||
let override_url = override_url.map(|s| s.to_string());
|
||||
let local_sem_clone = local_sem.clone();
|
||||
let dispatch_arc = dispatch.clone();
|
||||
|
||||
set.spawn(async move {
|
||||
// ✅ FIXED: use acquire_owned() so permit owns the semaphore Arc
|
||||
let _permit = if let Some(ls) = local_sem_clone {
|
||||
let sem = Arc::clone(&ls);
|
||||
let p = sem.acquire_owned().await.expect("local semaphore poisoned");
|
||||
Some(p)
|
||||
} else {
|
||||
let sem = Arc::clone(&SEMAPHORE);
|
||||
let p = sem.acquire_owned().await.expect("global semaphore poisoned");
|
||||
Some(p)
|
||||
};
|
||||
|
||||
// Call send_all (each dispatch returns Vec<anyhow::Result<Value>>)
|
||||
let res = dispatch_arc
|
||||
.send_all(
|
||||
client,
|
||||
keys,
|
||||
override_url.as_deref(),
|
||||
sandbox,
|
||||
None, // method_override not used here
|
||||
)
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(vec_vals) => Ok(vec_vals),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Collect results, flatten Vec<Vec<Result<Value>>> -> Vec<Result<Value>>
|
||||
let mut out: Vec<anyhow::Result<serde_json::Value>> = Vec::new();
|
||||
|
||||
while let Some(join_res) = set.join_next().await {
|
||||
match join_res {
|
||||
Ok(inner_res) => match inner_res {
|
||||
Ok(vec_vals) => out.extend(vec_vals),
|
||||
Err(e) => out.push(Err(e)),
|
||||
},
|
||||
Err(join_err) => {
|
||||
out.push(Err(anyhow::anyhow!("task join error: {}", join_err)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user