ud
This commit is contained in:
1605
Cargo.lock
generated
1605
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -8,8 +8,12 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
anyhow = "1.0.100"
|
||||
async-trait = "0.1"
|
||||
awc = "3"
|
||||
reqwest = { version = "0.12.23" , features = ["json", "rustls-tls"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1.0.145"
|
||||
toml = "0.9.5"
|
||||
urlencoding = "2"
|
||||
tokio = { version = "1.47", features = ["full"] }
|
||||
once_cell = "1.19"
|
||||
num_cpus = "1.16"
|
||||
|
||||
|
||||
357
src/lib.rs
357
src/lib.rs
@ -1,56 +1,24 @@
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, fs};
|
||||
use once_cell::sync::Lazy;
|
||||
use num_cpus;
|
||||
use reqwest::{Client, Method, header::HeaderMap };
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, fs, path::Path, sync::Arc};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
/// A trait every API request type should implement (via macro).
|
||||
/// Provides info about the HTTP protocol, URLs, and supported methods.
|
||||
pub trait HasHttp {
|
||||
/// Methods this request supports (GET, POST, etc.)
|
||||
fn http_methods() -> &'static [&'static str];
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// GLOBAL CONCURRENCY CONTROL
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
/// Production endpoint (base URL).
|
||||
fn live_url() -> &'static str;
|
||||
|
||||
/// Sandbox endpoint (base URL).
|
||||
fn sandbox_url() -> &'static str;
|
||||
}
|
||||
|
||||
/// Trait implemented by all request types that can be sent to an HTTP API.
|
||||
/// Usually derived via `#[http(...)]`.
|
||||
#[async_trait]
|
||||
pub trait Queryable: HasHttp + Send + Sync {
|
||||
type R;
|
||||
type E: std::error::Error + Send + Sync + 'static; // Keep trait bound
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::R>;
|
||||
}
|
||||
|
||||
/// Trait implemented by your auto-generated API dispatcher enums.
|
||||
/// This allows you to batch or route requests by API type.
|
||||
#[async_trait]
|
||||
pub trait ApiDispatch {
|
||||
type R;
|
||||
|
||||
async fn send_all(
|
||||
&self,
|
||||
client: Arc<awc::Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<Vec<Self::R>>;
|
||||
}
|
||||
/// Shared semaphore to limit concurrent requests (2 × CPU threads).
|
||||
static SEMAPHORE: Lazy<Arc<Semaphore>> =
|
||||
Lazy::new(|| Arc::new(Semaphore::new(num_cpus::get() * 2)));
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// API KEY MANAGEMENT
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Key {
|
||||
@ -64,17 +32,16 @@ pub struct Keys {
|
||||
}
|
||||
|
||||
impl Keys {
|
||||
/// Convenience method to get a key by name (e.g. "alpaca").
|
||||
pub fn get(&self, name: &str) -> Option<&Key> {
|
||||
self.keys.get(name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads API keys from ~/.config/keys or fallback ./keys directory.
|
||||
/// Each `.toml` file should contain a `Key { key, secret }` struct.
|
||||
/// The filename (without `.toml`) becomes the key name in the map.
|
||||
/// Load API keys from either:
|
||||
/// - `$HOME/.config/keys`
|
||||
/// - or local `./keys` directory
|
||||
pub fn load_api_keys() -> Result<Keys> {
|
||||
let home_dir = std::env::var("HOME")?;
|
||||
let home_dir = std::env::var("HOME").context("Failed to read $HOME")?;
|
||||
let config_dir = Path::new(&home_dir).join(".config/keys");
|
||||
let fallback_dir = Path::new("keys");
|
||||
|
||||
@ -86,14 +53,17 @@ pub fn load_api_keys() -> Result<Keys> {
|
||||
|
||||
let mut keys_map = HashMap::new();
|
||||
|
||||
for entry in fs::read_dir(dir_path)? {
|
||||
let entry = entry?;
|
||||
for entry in fs::read_dir(&dir_path)
|
||||
.with_context(|| format!("Failed to read keys directory: {}", dir_path.display()))?
|
||||
{
|
||||
let entry = entry.context("Failed to read directory entry")?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("toml") {
|
||||
let contents = fs::read_to_string(&path)?;
|
||||
let key: Key = toml::from_str(&contents)?;
|
||||
|
||||
let contents = fs::read_to_string(&path)
|
||||
.with_context(|| format!("Failed to read file: {}", path.display()))?;
|
||||
let key: Key = toml::from_str(&contents)
|
||||
.with_context(|| format!("Failed to parse TOML: {}", path.display()))?;
|
||||
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
keys_map.insert(stem.to_string(), key);
|
||||
}
|
||||
@ -103,15 +73,268 @@ pub fn load_api_keys() -> Result<Keys> {
|
||||
Ok(Keys { keys: keys_map })
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConversionError(pub ::std::borrow::Cow<'static, str>);
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// CORE HTTP REQUEST HELPER
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
impl fmt::Display for ConversionError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
/// Shared helper to send HTTP requests with concurrency limiting, API key injection,
|
||||
/// and optional custom headers.
|
||||
pub async fn send_request<T, R>(
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
api_name: &str,
|
||||
base_url: &str,
|
||||
path_template: &str,
|
||||
override_url: Option<&str>,
|
||||
method: &str,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &T,
|
||||
) -> Result<R>
|
||||
where
|
||||
T: Serialize + Send + Sync,
|
||||
R: DeserializeOwned + Send + 'static,
|
||||
{
|
||||
// Acquire semaphore permit (limits concurrency)
|
||||
let _permit = SEMAPHORE.acquire().await.expect("Semaphore poisoned");
|
||||
|
||||
// Build initial URL: override or base + path_template
|
||||
let mut url = if let Some(o) = override_url {
|
||||
o.to_string()
|
||||
} else {
|
||||
if base_url.ends_with('/') {
|
||||
format!("{}{}", base_url.trim_end_matches('/'), path_template)
|
||||
} else {
|
||||
format!(
|
||||
"{}/{}",
|
||||
base_url.trim_end_matches('/'),
|
||||
path_template.trim_start_matches('/')
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Apply path + query replacements (safe no-ops if none exist)
|
||||
url = replace_path_params(&url, &[], body)
|
||||
.context("Failed to replace path params")?;
|
||||
url = append_query_params(&url, &[], body)
|
||||
.context("Failed to append query params")?;
|
||||
|
||||
// Load API keys (pre-provided Arc<Keys>)
|
||||
let (api_key, secret_key) = keys
|
||||
.get(api_name)
|
||||
.map(|k| (k.key.clone(), k.secret.clone()))
|
||||
.unwrap_or_else(|| (String::new(), String::new()));
|
||||
|
||||
// Parse HTTP method into reqwest::Method
|
||||
let method_obj =
|
||||
Method::from_bytes(method.as_bytes()).context("Invalid HTTP method string")?;
|
||||
|
||||
// Build request
|
||||
let mut req = client.request(method_obj.clone(), &url);
|
||||
|
||||
// Attach API credentials if present
|
||||
if !api_key.is_empty() {
|
||||
req = req
|
||||
.header("APCA-API-KEY-ID", api_key)
|
||||
.header("APCA-API-SECRET-KEY", secret_key);
|
||||
}
|
||||
|
||||
// ✅ Add custom headers if provided
|
||||
if let Some(extra_headers) = headers {
|
||||
for (k, v) in extra_headers {
|
||||
req = req.header(k, v);
|
||||
}
|
||||
}
|
||||
|
||||
// Attach JSON body only for non-GET/HEAD methods
|
||||
if !(method_obj == Method::GET || method_obj == Method::HEAD) {
|
||||
req = req.json(body);
|
||||
}
|
||||
|
||||
// Send
|
||||
let resp = req.send().await.context("HTTP send failed")?;
|
||||
let status = resp.status();
|
||||
|
||||
if status.is_success() {
|
||||
let parsed: R = resp.json().await.context("Failed to parse JSON response")?;
|
||||
Ok(parsed)
|
||||
} else {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow!(
|
||||
"[{}] HTTP {} failed: {}",
|
||||
api_name,
|
||||
status,
|
||||
text
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ConversionError {}
|
||||
unsafe impl Send for ConversionError {}
|
||||
unsafe impl Sync for ConversionError {}
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// TRAITS
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
pub trait HasHttp {
|
||||
fn http_methods() -> &'static [&'static str];
|
||||
fn live_url() -> &'static str;
|
||||
fn sandbox_url() -> &'static str;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Queryable: Send + Sync + 'static {
|
||||
type Response: Send + Sync + 'static;
|
||||
|
||||
/// Primary method: send the request using a provided `keys` set.
|
||||
async fn send_with_keys(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::Response>;
|
||||
|
||||
/// Convenience wrapper: loads keys automatically from disk.
|
||||
async fn send(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::Response> {
|
||||
let keys = load_api_keys().map_err(|e| anyhow!("Failed to load API keys: {}", e))?;
|
||||
self.send_with_keys(client, Arc::new(keys), override_url, sandbox, method_override, headers)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ApiDispatch: Send + Sync + 'static {
|
||||
type Request;
|
||||
type Response;
|
||||
type Error;
|
||||
|
||||
async fn send_all(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<Vec<anyhow::Result<Value>>>;
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// API ENDPOINT STRUCTURE + HELPERS
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiEndpoint {
|
||||
pub name: String,
|
||||
pub operation_id: Option<String>,
|
||||
pub method: String,
|
||||
pub path_template: String,
|
||||
pub path_params: Vec<String>,
|
||||
pub query_params: Vec<String>,
|
||||
pub live_url: String,
|
||||
pub sandbox_url: String,
|
||||
pub request_type: Option<String>,
|
||||
pub responses: Vec<(String, String)>,
|
||||
pub requires_auth: bool,
|
||||
}
|
||||
|
||||
pub fn replace_path_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
|
||||
let mut url = url.to_string();
|
||||
let val = serde_json::to_value(req)?;
|
||||
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in path param serialization"))?;
|
||||
|
||||
for p in params {
|
||||
if let Some(v) = obj.get(p) {
|
||||
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
|
||||
url = url.replace(&format!("{{{}}}", p), &s);
|
||||
}
|
||||
}
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
pub fn append_query_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
|
||||
if params.is_empty() {
|
||||
return Ok(url.to_string());
|
||||
}
|
||||
|
||||
let val = serde_json::to_value(req)?;
|
||||
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in query param serialization"))?;
|
||||
|
||||
let mut query_pairs = Vec::new();
|
||||
for p in params {
|
||||
if let Some(v) = obj.get(p) {
|
||||
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
|
||||
query_pairs.push(format!("{}={}", p, urlencoding::encode(&s)));
|
||||
}
|
||||
}
|
||||
|
||||
if query_pairs.is_empty() {
|
||||
Ok(url.to_string())
|
||||
} else if url.contains('?') {
|
||||
Ok(format!("{}&{}", url, query_pairs.join("&")))
|
||||
} else {
|
||||
Ok(format!("{}?{}", url, query_pairs.join("&")))
|
||||
}
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// BLANKET ADAPTER — make any Queryable act as an ApiDispatch
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
#[async_trait]
|
||||
impl<T> ApiDispatch for T
|
||||
where
|
||||
T: Queryable + Serialize + Send + Sync,
|
||||
T::Response: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
type Request = ();
|
||||
type Response = ();
|
||||
type Error = anyhow::Error;
|
||||
|
||||
async fn send_all(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<Vec<anyhow::Result<Value>>> {
|
||||
let transport_res = self
|
||||
.send_with_keys(client, keys, override_url, sandbox, method_override, None)
|
||||
.await;
|
||||
|
||||
let res: anyhow::Result<Value> = match transport_res {
|
||||
Ok(v) => Ok(serde_json::to_value(v).unwrap_or(Value::Null)),
|
||||
Err(e) => Err(e),
|
||||
};
|
||||
|
||||
Ok(vec![res])
|
||||
}
|
||||
}
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// UNIVERSAL HTTP HELPER (for generated API structs)
|
||||
/// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn http_helper<T>(
|
||||
req: &T,
|
||||
client: Arc<Client>,
|
||||
keys: Arc<Keys>,
|
||||
override_url: Option<String>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> Result<T::Response>
|
||||
where
|
||||
T: Queryable + Serialize + Sync,
|
||||
T::Response: Send + Sync + 'static,
|
||||
{
|
||||
let override_str = override_url.as_deref();
|
||||
req.send_with_keys(client, keys, override_str, sandbox, method_override, None)
|
||||
.await
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user