This commit is contained in:
2025-09-08 14:53:18 -04:00
parent fc96d91c77
commit 3043170396
2 changed files with 98 additions and 163 deletions

View File

@ -5,12 +5,12 @@ use quote::format_ident;
use http_core::{Queryable, ApiDispatch, HasHttp, Keys};
#[proc_macro_derive(HttpRequest, attributes(http_response, http_error_type))]
pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
pub fn derive_http_request(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let query_name = &input.ident;
let query_name_str = query_name.to_string();
// Parse optional #[http_response = "..."] attribute
// Parse optional #[http_response = "..."]
let mut response_name_opt: Option<String> = None;
for attr in &input.attrs {
if attr.path().is_ident("http_response") {
@ -22,15 +22,14 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
}
}
Ok(())
}).unwrap_or_else(|e| panic!("Error parsing http_response attribute: {}", e));
}).unwrap();
}
}
let response_name_str = response_name_opt.unwrap_or_else(|| format!("{}Resp", query_name_str));
let response_name = format_ident!("{}", response_name_str);
// Parse optional #[http_error_type = "..."] attribute (default to `E`)
let mut error_type = syn::Path::from(syn::Ident::new("E", proc_macro2::Span::call_site()));
// Parse optional #[http_error_type = "..."] (default to Box<dyn Error>)
let mut error_type = syn::parse_str::<syn::Path>("Box<dyn std::error::Error>").unwrap();
for attr in &input.attrs {
if attr.path().is_ident("http_error_type") {
attr.parse_nested_meta(|meta| {
@ -41,11 +40,11 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
}
}
Ok(())
}).unwrap_or_else(|e| panic!("Error parsing http_error_type attribute: {}", e));
}).unwrap();
}
}
// Collect query parameters from lnk_p_* fields
// Collect query parameters from fields prefixed with lnk_p_
let query_param_code = if let Data::Struct(data_struct) = &input.data {
if let Fields::Named(fields_named) = &data_struct.fields {
fields_named.named.iter().filter_map(|field| {
@ -58,51 +57,62 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
query_params.push((#key.to_string(), val.to_string()));
}
})
} else {
None
}
} else { None }
}).collect::<Vec<_>>()
} else {
Vec::new()
}
} else {
Vec::new()
};
} else { Vec::new() }
} else { Vec::new() };
let expanded = quote! {
#[async_trait::async_trait]
impl Queryable for #query_name {
impl http_core::Queryable for #query_name {
type R = #response_name;
type E = #error_type;
async fn send(
&self,
base_url: &str,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
headers: Option<Vec<(&str, &str)>>,
) -> Result<Self::R, #error_type> {
) -> Result<Self::R, Self::E> {
use awc::Client;
use urlencoding::encode;
use http_core::HasHttp;
// collect lnk_p_* query params
let mut query_params: Vec<(String, String)> = Vec::new();
#(#query_param_code)*
let mut url = base_url.to_string();
// pick URL
let mut url = if let Some(u) = override_url {
u.to_string()
} else if sandbox {
<Self as HasHttp>::sandbox_url().to_string()
} else {
<Self as HasHttp>::live_url().to_string()
};
if !query_params.is_empty() {
let mut query_string = String::new();
let mut first = true;
for (k, v) in &query_params {
if !first {
query_string.push('&');
}
first = false;
query_string.push_str(&format!("{}={}", k, encode(v)));
}
let query_string = query_params.into_iter()
.map(|(k,v)| format!("{}={}", k, encode(&v)))
.collect::<Vec<_>>()
.join("&");
url.push('?');
url.push_str(&query_string);
}
// choose method
let client = Client::default();
let mut request = client.get(url);
let mut request = match method_override.unwrap_or("GET") {
"GET" => client.get(url),
"POST" => client.post(url),
"PUT" => client.put(url),
"DELETE" => client.delete(url),
"PATCH" => client.patch(url),
m => panic!("Unsupported method override: {}", m),
};
// add headers
if let Some(hdrs) = headers {
for (k, v) in hdrs {
request = request.append_header((k, v));
@ -120,7 +130,6 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input_enum = parse_macro_input!(item as ItemEnum);
@ -159,31 +168,19 @@ pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
let client = Arc::new(awc::Client::default());
let keys = Arc::new(crate::load_api_keys()?);
const THREADS: usize = 4;
let total = queries.len();
let per_thread = std::cmp::max(1, total / THREADS);
let shared_queries = Arc::new(queries);
// Spawn all queries as async tasks
let mut handles = Vec::new();
for i in 0..THREADS {
let queries = Arc::clone(&shared_queries);
for q in queries {
let client = Arc::clone(&client);
let keys = Arc::clone(&keys);
let start = i * per_thread;
let end = if i == THREADS - 1 { total } else { start + per_thread };
let handle = std::thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
for q in &queries[start..end] {
rt.block_on(q.send_all(&client, &keys)).unwrap();
}
});
handles.push(handle);
handles.push(tokio::spawn(async move {
q.send_all(&client, &keys).await
}));
}
// Await all results and propagate first error (if any)
for h in handles {
h.join().expect("Thread panicked");
h.await??;
}
}
other => {
@ -212,71 +209,6 @@ pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
}
Ok(())
}
// Trait for dispatching API calls
pub trait ApiDispatch {
fn send_all(
&self,
client: &awc::Client,
keys: &std::collections::HashMap<String, crate::Keys>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send>>;
}
};
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn api_dispatch(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as syn::ItemEnum);
let enum_ident = &input.ident;
// Parse attribute input: input = "MyQuery"
let meta_args = attr.to_string();
let input_type: syn::Ident = {
let cleaned = meta_args.trim().replace("input", "").replace('=', "").replace('"', "").trim().to_string();
syn::Ident::new(&cleaned, proc_macro2::Span::call_site())
};
let expanded = quote! {
#input
impl ApiDispatch for #enum_ident {
fn send_all(
&self,
client: &awc::Client,
keys: &std::collections::HashMap<String, crate::Keys>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send>> {
Box::pin(async move {
match self {
#enum_ident::Single { query } => {
let parsed: #input_type = serde_json::from_str(query)?;
parsed.send(client, keys).await?;
}
#enum_ident::Bulk { input } => {
let json = if let Some(raw) = input {
if std::path::Path::new(&raw).exists() {
std::fs::read_to_string(raw)?
} else {
raw.clone()
}
} else {
use std::io::Read;
let mut buf = String::new();
std::io::stdin().read_to_string(&mut buf)?;
buf
};
let items: Vec<#input_type> = serde_json::from_str(&json)?;
for item in items {
item.send(client, keys).await?;
}
}
}
Ok(())
})
}
}
};
TokenStream::from(expanded)