ud
This commit is contained in:
1605
Cargo.lock
generated
1605
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -8,8 +8,12 @@ license = "MIT OR Apache-2.0"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.100"
|
anyhow = "1.0.100"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
awc = "3"
|
reqwest = { version = "0.12.23" , features = ["json", "rustls-tls"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1.0.145"
|
||||||
toml = "0.9.5"
|
toml = "0.9.5"
|
||||||
urlencoding = "2"
|
urlencoding = "2"
|
||||||
|
tokio = { version = "1.47", features = ["full"] }
|
||||||
|
once_cell = "1.19"
|
||||||
|
num_cpus = "1.16"
|
||||||
|
|
||||||
|
|||||||
357
src/lib.rs
357
src/lib.rs
@ -1,56 +1,24 @@
|
|||||||
|
use anyhow::{anyhow, Context, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use anyhow::Result;
|
use once_cell::sync::Lazy;
|
||||||
use serde::{Deserialize, Serialize};
|
use num_cpus;
|
||||||
use std::collections::HashMap;
|
use reqwest::{Client, Method, header::HeaderMap };
|
||||||
use std::path::Path;
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use std::sync::Arc;
|
use serde_json::Value;
|
||||||
use std::{fmt, fs};
|
use std::{collections::HashMap, fs, path::Path, sync::Arc};
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
/// A trait every API request type should implement (via macro).
|
/// ---------------------------------------------------------------------------
|
||||||
/// Provides info about the HTTP protocol, URLs, and supported methods.
|
/// GLOBAL CONCURRENCY CONTROL
|
||||||
pub trait HasHttp {
|
/// ---------------------------------------------------------------------------
|
||||||
/// Methods this request supports (GET, POST, etc.)
|
|
||||||
fn http_methods() -> &'static [&'static str];
|
|
||||||
|
|
||||||
/// Production endpoint (base URL).
|
/// Shared semaphore to limit concurrent requests (2 × CPU threads).
|
||||||
fn live_url() -> &'static str;
|
static SEMAPHORE: Lazy<Arc<Semaphore>> =
|
||||||
|
Lazy::new(|| Arc::new(Semaphore::new(num_cpus::get() * 2)));
|
||||||
/// Sandbox endpoint (base URL).
|
|
||||||
fn sandbox_url() -> &'static str;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait implemented by all request types that can be sent to an HTTP API.
|
|
||||||
/// Usually derived via `#[http(...)]`.
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Queryable: HasHttp + Send + Sync {
|
|
||||||
type R;
|
|
||||||
type E: std::error::Error + Send + Sync + 'static; // Keep trait bound
|
|
||||||
|
|
||||||
async fn send(
|
|
||||||
&self,
|
|
||||||
override_url: Option<&str>,
|
|
||||||
sandbox: bool,
|
|
||||||
method_override: Option<&str>,
|
|
||||||
headers: Option<Vec<(&str, &str)>>,
|
|
||||||
) -> Result<Self::R>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait implemented by your auto-generated API dispatcher enums.
|
|
||||||
/// This allows you to batch or route requests by API type.
|
|
||||||
#[async_trait]
|
|
||||||
pub trait ApiDispatch {
|
|
||||||
type R;
|
|
||||||
|
|
||||||
async fn send_all(
|
|
||||||
&self,
|
|
||||||
client: Arc<awc::Client>,
|
|
||||||
keys: Arc<Keys>,
|
|
||||||
override_url: Option<&str>,
|
|
||||||
sandbox: bool,
|
|
||||||
method_override: Option<&str>,
|
|
||||||
) -> Result<Vec<Self::R>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
/// API KEY MANAGEMENT
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Key {
|
pub struct Key {
|
||||||
@ -64,17 +32,16 @@ pub struct Keys {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Keys {
|
impl Keys {
|
||||||
/// Convenience method to get a key by name (e.g. "alpaca").
|
|
||||||
pub fn get(&self, name: &str) -> Option<&Key> {
|
pub fn get(&self, name: &str) -> Option<&Key> {
|
||||||
self.keys.get(name)
|
self.keys.get(name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Loads API keys from ~/.config/keys or fallback ./keys directory.
|
/// Load API keys from either:
|
||||||
/// Each `.toml` file should contain a `Key { key, secret }` struct.
|
/// - `$HOME/.config/keys`
|
||||||
/// The filename (without `.toml`) becomes the key name in the map.
|
/// - or local `./keys` directory
|
||||||
pub fn load_api_keys() -> Result<Keys> {
|
pub fn load_api_keys() -> Result<Keys> {
|
||||||
let home_dir = std::env::var("HOME")?;
|
let home_dir = std::env::var("HOME").context("Failed to read $HOME")?;
|
||||||
let config_dir = Path::new(&home_dir).join(".config/keys");
|
let config_dir = Path::new(&home_dir).join(".config/keys");
|
||||||
let fallback_dir = Path::new("keys");
|
let fallback_dir = Path::new("keys");
|
||||||
|
|
||||||
@ -86,14 +53,17 @@ pub fn load_api_keys() -> Result<Keys> {
|
|||||||
|
|
||||||
let mut keys_map = HashMap::new();
|
let mut keys_map = HashMap::new();
|
||||||
|
|
||||||
for entry in fs::read_dir(dir_path)? {
|
for entry in fs::read_dir(&dir_path)
|
||||||
let entry = entry?;
|
.with_context(|| format!("Failed to read keys directory: {}", dir_path.display()))?
|
||||||
|
{
|
||||||
|
let entry = entry.context("Failed to read directory entry")?;
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
|
|
||||||
if path.extension().and_then(|s| s.to_str()) == Some("toml") {
|
if path.extension().and_then(|s| s.to_str()) == Some("toml") {
|
||||||
let contents = fs::read_to_string(&path)?;
|
let contents = fs::read_to_string(&path)
|
||||||
let key: Key = toml::from_str(&contents)?;
|
.with_context(|| format!("Failed to read file: {}", path.display()))?;
|
||||||
|
let key: Key = toml::from_str(&contents)
|
||||||
|
.with_context(|| format!("Failed to parse TOML: {}", path.display()))?;
|
||||||
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
||||||
keys_map.insert(stem.to_string(), key);
|
keys_map.insert(stem.to_string(), key);
|
||||||
}
|
}
|
||||||
@ -103,15 +73,268 @@ pub fn load_api_keys() -> Result<Keys> {
|
|||||||
Ok(Keys { keys: keys_map })
|
Ok(Keys { keys: keys_map })
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
/// ---------------------------------------------------------------------------
|
||||||
pub struct ConversionError(pub ::std::borrow::Cow<'static, str>);
|
/// CORE HTTP REQUEST HELPER
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
impl fmt::Display for ConversionError {
|
/// Shared helper to send HTTP requests with concurrency limiting, API key injection,
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
/// and optional custom headers.
|
||||||
write!(f, "{}", self.0)
|
pub async fn send_request<T, R>(
|
||||||
|
client: Arc<Client>,
|
||||||
|
keys: Arc<Keys>,
|
||||||
|
api_name: &str,
|
||||||
|
base_url: &str,
|
||||||
|
path_template: &str,
|
||||||
|
override_url: Option<&str>,
|
||||||
|
method: &str,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
body: &T,
|
||||||
|
) -> Result<R>
|
||||||
|
where
|
||||||
|
T: Serialize + Send + Sync,
|
||||||
|
R: DeserializeOwned + Send + 'static,
|
||||||
|
{
|
||||||
|
// Acquire semaphore permit (limits concurrency)
|
||||||
|
let _permit = SEMAPHORE.acquire().await.expect("Semaphore poisoned");
|
||||||
|
|
||||||
|
// Build initial URL: override or base + path_template
|
||||||
|
let mut url = if let Some(o) = override_url {
|
||||||
|
o.to_string()
|
||||||
|
} else {
|
||||||
|
if base_url.ends_with('/') {
|
||||||
|
format!("{}{}", base_url.trim_end_matches('/'), path_template)
|
||||||
|
} else {
|
||||||
|
format!(
|
||||||
|
"{}/{}",
|
||||||
|
base_url.trim_end_matches('/'),
|
||||||
|
path_template.trim_start_matches('/')
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply path + query replacements (safe no-ops if none exist)
|
||||||
|
url = replace_path_params(&url, &[], body)
|
||||||
|
.context("Failed to replace path params")?;
|
||||||
|
url = append_query_params(&url, &[], body)
|
||||||
|
.context("Failed to append query params")?;
|
||||||
|
|
||||||
|
// Load API keys (pre-provided Arc<Keys>)
|
||||||
|
let (api_key, secret_key) = keys
|
||||||
|
.get(api_name)
|
||||||
|
.map(|k| (k.key.clone(), k.secret.clone()))
|
||||||
|
.unwrap_or_else(|| (String::new(), String::new()));
|
||||||
|
|
||||||
|
// Parse HTTP method into reqwest::Method
|
||||||
|
let method_obj =
|
||||||
|
Method::from_bytes(method.as_bytes()).context("Invalid HTTP method string")?;
|
||||||
|
|
||||||
|
// Build request
|
||||||
|
let mut req = client.request(method_obj.clone(), &url);
|
||||||
|
|
||||||
|
// Attach API credentials if present
|
||||||
|
if !api_key.is_empty() {
|
||||||
|
req = req
|
||||||
|
.header("APCA-API-KEY-ID", api_key)
|
||||||
|
.header("APCA-API-SECRET-KEY", secret_key);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ Add custom headers if provided
|
||||||
|
if let Some(extra_headers) = headers {
|
||||||
|
for (k, v) in extra_headers {
|
||||||
|
req = req.header(k, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach JSON body only for non-GET/HEAD methods
|
||||||
|
if !(method_obj == Method::GET || method_obj == Method::HEAD) {
|
||||||
|
req = req.json(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send
|
||||||
|
let resp = req.send().await.context("HTTP send failed")?;
|
||||||
|
let status = resp.status();
|
||||||
|
|
||||||
|
if status.is_success() {
|
||||||
|
let parsed: R = resp.json().await.context("Failed to parse JSON response")?;
|
||||||
|
Ok(parsed)
|
||||||
|
} else {
|
||||||
|
let text = resp.text().await.unwrap_or_default();
|
||||||
|
Err(anyhow!(
|
||||||
|
"[{}] HTTP {} failed: {}",
|
||||||
|
api_name,
|
||||||
|
status,
|
||||||
|
text
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for ConversionError {}
|
/// ---------------------------------------------------------------------------
|
||||||
unsafe impl Send for ConversionError {}
|
/// TRAITS
|
||||||
unsafe impl Sync for ConversionError {}
|
/// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
pub trait HasHttp {
|
||||||
|
fn http_methods() -> &'static [&'static str];
|
||||||
|
fn live_url() -> &'static str;
|
||||||
|
fn sandbox_url() -> &'static str;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Queryable: Send + Sync + 'static {
|
||||||
|
type Response: Send + Sync + 'static;
|
||||||
|
|
||||||
|
/// Primary method: send the request using a provided `keys` set.
|
||||||
|
async fn send_with_keys(
|
||||||
|
&self,
|
||||||
|
client: Arc<Client>,
|
||||||
|
keys: Arc<Keys>,
|
||||||
|
override_url: Option<&str>,
|
||||||
|
sandbox: bool,
|
||||||
|
method_override: Option<&str>,
|
||||||
|
headers: Option<Vec<(&str, &str)>>,
|
||||||
|
) -> Result<Self::Response>;
|
||||||
|
|
||||||
|
/// Convenience wrapper: loads keys automatically from disk.
|
||||||
|
async fn send(
|
||||||
|
&self,
|
||||||
|
client: Arc<Client>,
|
||||||
|
override_url: Option<&str>,
|
||||||
|
sandbox: bool,
|
||||||
|
method_override: Option<&str>,
|
||||||
|
headers: Option<Vec<(&str, &str)>>,
|
||||||
|
) -> Result<Self::Response> {
|
||||||
|
let keys = load_api_keys().map_err(|e| anyhow!("Failed to load API keys: {}", e))?;
|
||||||
|
self.send_with_keys(client, Arc::new(keys), override_url, sandbox, method_override, headers)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait ApiDispatch: Send + Sync + 'static {
|
||||||
|
type Request;
|
||||||
|
type Response;
|
||||||
|
type Error;
|
||||||
|
|
||||||
|
async fn send_all(
|
||||||
|
&self,
|
||||||
|
client: Arc<Client>,
|
||||||
|
keys: Arc<Keys>,
|
||||||
|
override_url: Option<&str>,
|
||||||
|
sandbox: bool,
|
||||||
|
method_override: Option<&str>,
|
||||||
|
) -> Result<Vec<anyhow::Result<Value>>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
/// API ENDPOINT STRUCTURE + HELPERS
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ApiEndpoint {
|
||||||
|
pub name: String,
|
||||||
|
pub operation_id: Option<String>,
|
||||||
|
pub method: String,
|
||||||
|
pub path_template: String,
|
||||||
|
pub path_params: Vec<String>,
|
||||||
|
pub query_params: Vec<String>,
|
||||||
|
pub live_url: String,
|
||||||
|
pub sandbox_url: String,
|
||||||
|
pub request_type: Option<String>,
|
||||||
|
pub responses: Vec<(String, String)>,
|
||||||
|
pub requires_auth: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn replace_path_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
|
||||||
|
let mut url = url.to_string();
|
||||||
|
let val = serde_json::to_value(req)?;
|
||||||
|
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in path param serialization"))?;
|
||||||
|
|
||||||
|
for p in params {
|
||||||
|
if let Some(v) = obj.get(p) {
|
||||||
|
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
|
||||||
|
url = url.replace(&format!("{{{}}}", p), &s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn append_query_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> {
|
||||||
|
if params.is_empty() {
|
||||||
|
return Ok(url.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let val = serde_json::to_value(req)?;
|
||||||
|
let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in query param serialization"))?;
|
||||||
|
|
||||||
|
let mut query_pairs = Vec::new();
|
||||||
|
for p in params {
|
||||||
|
if let Some(v) = obj.get(p) {
|
||||||
|
let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string());
|
||||||
|
query_pairs.push(format!("{}={}", p, urlencoding::encode(&s)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if query_pairs.is_empty() {
|
||||||
|
Ok(url.to_string())
|
||||||
|
} else if url.contains('?') {
|
||||||
|
Ok(format!("{}&{}", url, query_pairs.join("&")))
|
||||||
|
} else {
|
||||||
|
Ok(format!("{}?{}", url, query_pairs.join("&")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
/// BLANKET ADAPTER — make any Queryable act as an ApiDispatch
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T> ApiDispatch for T
|
||||||
|
where
|
||||||
|
T: Queryable + Serialize + Send + Sync,
|
||||||
|
T::Response: Serialize + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type Request = ();
|
||||||
|
type Response = ();
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
async fn send_all(
|
||||||
|
&self,
|
||||||
|
client: Arc<Client>,
|
||||||
|
keys: Arc<Keys>,
|
||||||
|
override_url: Option<&str>,
|
||||||
|
sandbox: bool,
|
||||||
|
method_override: Option<&str>,
|
||||||
|
) -> Result<Vec<anyhow::Result<Value>>> {
|
||||||
|
let transport_res = self
|
||||||
|
.send_with_keys(client, keys, override_url, sandbox, method_override, None)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let res: anyhow::Result<Value> = match transport_res {
|
||||||
|
Ok(v) => Ok(serde_json::to_value(v).unwrap_or(Value::Null)),
|
||||||
|
Err(e) => Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(vec![res])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
/// UNIVERSAL HTTP HELPER (for generated API structs)
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
pub async fn http_helper<T>(
|
||||||
|
req: &T,
|
||||||
|
client: Arc<Client>,
|
||||||
|
keys: Arc<Keys>,
|
||||||
|
override_url: Option<String>,
|
||||||
|
sandbox: bool,
|
||||||
|
method_override: Option<&str>,
|
||||||
|
) -> Result<T::Response>
|
||||||
|
where
|
||||||
|
T: Queryable + Serialize + Sync,
|
||||||
|
T::Response: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let override_str = override_url.as_deref();
|
||||||
|
req.send_with_keys(client, keys, override_str, sandbox, method_override, None)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user