This commit is contained in:
2025-10-12 18:04:33 -04:00
parent e6db6bcc68
commit d1f845b068
3 changed files with 1469 additions and 499 deletions

1605
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -8,8 +8,12 @@ license = "MIT OR Apache-2.0"
[dependencies]
anyhow = "1.0.100"
async-trait = "0.1"
awc = "3"
reqwest = { version = "0.12.23" , features = ["json", "rustls-tls"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1.0.145"
toml = "0.9.5"
urlencoding = "2"
tokio = { version = "1.47", features = ["full"] }
once_cell = "1.19"
num_cpus = "1.16"

View File

@ -1,56 +1,24 @@
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::{fmt, fs};
use once_cell::sync::Lazy;
use num_cpus;
use reqwest::{Client, Method, header::HeaderMap };
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use tokio::sync::Semaphore;
/// A trait every API request type should implement (via macro).
/// Provides info about the HTTP protocol, URLs, and supported methods.
pub trait HasHttp {
/// Methods this request supports (GET, POST, etc.)
fn http_methods() -> &'static [&'static str];
/// ---------------------------------------------------------------------------
/// GLOBAL CONCURRENCY CONTROL
/// ---------------------------------------------------------------------------
/// Production endpoint (base URL).
fn live_url() -> &'static str;
/// Sandbox endpoint (base URL).
fn sandbox_url() -> &'static str;
}
/// Trait implemented by all request types that can be sent to an HTTP API.
/// Usually derived via `#[http(...)]`.
#[async_trait]
pub trait Queryable: HasHttp + Send + Sync {
type R;
type E: std::error::Error + Send + Sync + 'static; // Keep trait bound
async fn send(
&self,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
headers: Option<Vec<(&str, &str)>>,
) -> Result<Self::R>;
}
/// Trait implemented by your auto-generated API dispatcher enums.
/// This allows you to batch or route requests by API type.
#[async_trait]
pub trait ApiDispatch {
type R;
async fn send_all(
&self,
client: Arc<awc::Client>,
keys: Arc<Keys>,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
) -> Result<Vec<Self::R>>;
}
/// Shared semaphore to limit concurrent requests (2 × CPU threads).
static SEMAPHORE: Lazy<Arc<Semaphore>> =
Lazy::new(|| Arc::new(Semaphore::new(num_cpus::get() * 2)));
/// ---------------------------------------------------------------------------
/// API KEY MANAGEMENT
/// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Key {
@ -64,17 +32,16 @@ pub struct Keys {
}
impl Keys {
/// Convenience method to get a key by name (e.g. "alpaca").
pub fn get(&self, name: &str) -> Option<&Key> {
self.keys.get(name)
}
}
/// Loads API keys from ~/.config/keys or fallback ./keys directory.
/// Each `.toml` file should contain a `Key { key, secret }` struct.
/// The filename (without `.toml`) becomes the key name in the map.
/// Load API keys from either:
/// - `$HOME/.config/keys`
/// - or local `./keys` directory
pub fn load_api_keys() -> Result<Keys> {
let home_dir = std::env::var("HOME")?;
let home_dir = std::env::var("HOME").context("Failed to read $HOME")?;
let config_dir = Path::new(&home_dir).join(".config/keys");
let fallback_dir = Path::new("keys");
@ -86,14 +53,17 @@ pub fn load_api_keys() -> Result<Keys> {
let mut keys_map = HashMap::new();
for entry in fs::read_dir(dir_path)? {
let entry = entry?;
for entry in fs::read_dir(&dir_path)
.with_context(|| format!("Failed to read keys directory: {}", dir_path.display()))?
{
let entry = entry.context("Failed to read directory entry")?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("toml") {
let contents = fs::read_to_string(&path)?;
let key: Key = toml::from_str(&contents)?;
let contents = fs::read_to_string(&path)
.with_context(|| format!("Failed to read file: {}", path.display()))?;
let key: Key = toml::from_str(&contents)
.with_context(|| format!("Failed to parse TOML: {}", path.display()))?;
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
keys_map.insert(stem.to_string(), key);
}
@ -103,15 +73,268 @@ pub fn load_api_keys() -> Result<Keys> {
Ok(Keys { keys: keys_map })
}
#[derive(Debug)]
pub struct ConversionError(pub ::std::borrow::Cow<'static, str>);
/// ---------------------------------------------------------------------------
/// CORE HTTP REQUEST HELPER
/// ---------------------------------------------------------------------------
impl fmt::Display for ConversionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
/// Shared helper to send HTTP requests with concurrency limiting, API key injection,
/// and optional custom headers.
pub async fn send_request<T, R>(
client: Arc<Client>,
keys: Arc<Keys>,
api_name: &str,
base_url: &str,
path_template: &str,
override_url: Option<&str>,
method: &str,
headers: Option<&HeaderMap>,
body: &T,
) -> Result<R>
where
T: Serialize + Send + Sync,
R: DeserializeOwned + Send + 'static,
{
// Acquire semaphore permit (limits concurrency)
let _permit = SEMAPHORE.acquire().await.expect("Semaphore poisoned");
// Build initial URL: override or base + path_template
let mut url = if let Some(o) = override_url {
o.to_string()
} else {
if base_url.ends_with('/') {
format!("{}{}", base_url.trim_end_matches('/'), path_template)
} else {
format!(
"{}/{}",
base_url.trim_end_matches('/'),
path_template.trim_start_matches('/')
)
}
};
// Apply path + query replacements (safe no-ops if none exist)
url = replace_path_params(&url, &[], body)
.context("Failed to replace path params")?;
url = append_query_params(&url, &[], body)
.context("Failed to append query params")?;
// Load API keys (pre-provided Arc<Keys>)
let (api_key, secret_key) = keys
.get(api_name)
.map(|k| (k.key.clone(), k.secret.clone()))
.unwrap_or_else(|| (String::new(), String::new()));
// Parse HTTP method into reqwest::Method
let method_obj =
Method::from_bytes(method.as_bytes()).context("Invalid HTTP method string")?;
// Build request
let mut req = client.request(method_obj.clone(), &url);
// Attach API credentials if present
if !api_key.is_empty() {
req = req
.header("APCA-API-KEY-ID", api_key)
.header("APCA-API-SECRET-KEY", secret_key);
}
// ✅ Add custom headers if provided
if let Some(extra_headers) = headers {
for (k, v) in extra_headers {
req = req.header(k, v);
}
}
// Attach JSON body only for non-GET/HEAD methods
if !(method_obj == Method::GET || method_obj == Method::HEAD) {
req = req.json(body);
}
// Send
let resp = req.send().await.context("HTTP send failed")?;
let status = resp.status();
if status.is_success() {
let parsed: R = resp.json().await.context("Failed to parse JSON response")?;
Ok(parsed)
} else {
let text = resp.text().await.unwrap_or_default();
Err(anyhow!(
"[{}] HTTP {} failed: {}",
api_name,
status,
text
))
}
}
impl std::error::Error for ConversionError {}
unsafe impl Send for ConversionError {}
unsafe impl Sync for ConversionError {}
/// ---------------------------------------------------------------------------
/// TRAITS
/// ---------------------------------------------------------------------------
pub trait HasHttp {
fn http_methods() -> &'static [&'static str];
fn live_url() -> &'static str;
fn sandbox_url() -> &'static str;
}
#[async_trait]
pub trait Queryable: Send + Sync + 'static {
type Response: Send + Sync + 'static;
/// Primary method: send the request using a provided `keys` set.
async fn send_with_keys(
&self,
client: Arc<Client>,
keys: Arc<Keys>,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
headers: Option<Vec<(&str, &str)>>,
) -> Result<Self::Response>;
/// Convenience wrapper: loads keys automatically from disk.
async fn send(
&self,
client: Arc<Client>,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
headers: Option<Vec<(&str, &str)>>,
) -> Result<Self::Response> {
let keys = load_api_keys().map_err(|e| anyhow!("Failed to load API keys: {}", e))?;
self.send_with_keys(client, Arc::new(keys), override_url, sandbox, method_override, headers)
.await
}
}
#[async_trait]
pub trait ApiDispatch: Send + Sync + 'static {
type Request;
type Response;
type Error;
async fn send_all(
&self,
client: Arc<Client>,
keys: Arc<Keys>,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
) -> Result<Vec<anyhow::Result<Value>>>;
}
/// ---------------------------------------------------------------------------
/// API ENDPOINT STRUCTURE + HELPERS
/// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiEndpoint {
pub name: String,
pub operation_id: Option<String>,
pub method: String,
pub path_template: String,
pub path_params: Vec<String>,
pub query_params: Vec<String>,
pub live_url: String,
pub sandbox_url: String,
pub request_type: Option<String>,
pub responses: Vec<(String, String)>,
pub requires_auth: bool,
}
pub fn replace_path_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
let mut url = url.to_string();
let val = serde_json::to_value(req)?;
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in path param serialization"))?;
for p in params {
if let Some(v) = obj.get(p) {
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
url = url.replace(&format!("{{{}}}", p), &s);
}
}
Ok(url)
}
pub fn append_query_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
if params.is_empty() {
return Ok(url.to_string());
}
let val = serde_json::to_value(req)?;
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in query param serialization"))?;
let mut query_pairs = Vec::new();
for p in params {
if let Some(v) = obj.get(p) {
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
query_pairs.push(format!("{}={}", p, urlencoding::encode(&s)));
}
}
if query_pairs.is_empty() {
Ok(url.to_string())
} else if url.contains('?') {
Ok(format!("{}&{}", url, query_pairs.join("&")))
} else {
Ok(format!("{}?{}", url, query_pairs.join("&")))
}
}
/// ---------------------------------------------------------------------------
/// BLANKET ADAPTER — make any Queryable act as an ApiDispatch
/// ---------------------------------------------------------------------------
#[async_trait]
impl<T> ApiDispatch for T
where
T: Queryable + Serialize + Send + Sync,
T::Response: Serialize + Send + Sync + 'static,
{
type Request = ();
type Response = ();
type Error = anyhow::Error;
async fn send_all(
&self,
client: Arc<Client>,
keys: Arc<Keys>,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
) -> Result<Vec<anyhow::Result<Value>>> {
let transport_res = self
.send_with_keys(client, keys, override_url, sandbox, method_override, None)
.await;
let res: anyhow::Result<Value> = match transport_res {
Ok(v) => Ok(serde_json::to_value(v).unwrap_or(Value::Null)),
Err(e) => Err(e),
};
Ok(vec![res])
}
}
/// ---------------------------------------------------------------------------
/// UNIVERSAL HTTP HELPER (for generated API structs)
/// ---------------------------------------------------------------------------
pub async fn http_helper<T>(
req: &T,
client: Arc<Client>,
keys: Arc<Keys>,
override_url: Option<String>,
sandbox: bool,
method_override: Option<&str>,
) -> Result<T::Response>
where
T: Queryable + Serialize + Sync,
T::Response: Send + Sync + 'static,
{
let override_str = override_url.as_deref();
req.send_with_keys(client, keys, override_str, sandbox, method_override, None)
.await
}