ud
This commit is contained in:
164
src/lib.rs
164
src/lib.rs
@ -5,12 +5,12 @@ use quote::format_ident;
|
||||
use http_core::{Queryable, ApiDispatch, HasHttp, Keys};
|
||||
|
||||
#[proc_macro_derive(HttpRequest, attributes(http_response, http_error_type))]
|
||||
pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
|
||||
pub fn derive_http_request(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
let query_name = &input.ident;
|
||||
let query_name_str = query_name.to_string();
|
||||
|
||||
// Parse optional #[http_response = "..."] attribute
|
||||
// Parse optional #[http_response = "..."]
|
||||
let mut response_name_opt: Option<String> = None;
|
||||
for attr in &input.attrs {
|
||||
if attr.path().is_ident("http_response") {
|
||||
@ -22,15 +22,14 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}).unwrap_or_else(|e| panic!("Error parsing http_response attribute: {}", e));
|
||||
}).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let response_name_str = response_name_opt.unwrap_or_else(|| format!("{}Resp", query_name_str));
|
||||
let response_name = format_ident!("{}", response_name_str);
|
||||
|
||||
// Parse optional #[http_error_type = "..."] attribute (default to `E`)
|
||||
let mut error_type = syn::Path::from(syn::Ident::new("E", proc_macro2::Span::call_site()));
|
||||
// Parse optional #[http_error_type = "..."] (default to Box<dyn Error>)
|
||||
let mut error_type = syn::parse_str::<syn::Path>("Box<dyn std::error::Error>").unwrap();
|
||||
for attr in &input.attrs {
|
||||
if attr.path().is_ident("http_error_type") {
|
||||
attr.parse_nested_meta(|meta| {
|
||||
@ -41,11 +40,11 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}).unwrap_or_else(|e| panic!("Error parsing http_error_type attribute: {}", e));
|
||||
}).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Collect query parameters from lnk_p_* fields
|
||||
// Collect query parameters from fields prefixed with lnk_p_
|
||||
let query_param_code = if let Data::Struct(data_struct) = &input.data {
|
||||
if let Fields::Named(fields_named) = &data_struct.fields {
|
||||
fields_named.named.iter().filter_map(|field| {
|
||||
@ -58,51 +57,62 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
|
||||
query_params.push((#key.to_string(), val.to_string()));
|
||||
}
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else { None }
|
||||
}).collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
} else { Vec::new() }
|
||||
} else { Vec::new() };
|
||||
|
||||
let expanded = quote! {
|
||||
#[async_trait::async_trait]
|
||||
impl Queryable for #query_name {
|
||||
impl http_core::Queryable for #query_name {
|
||||
type R = #response_name;
|
||||
type E = #error_type;
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
base_url: &str,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::R, #error_type> {
|
||||
) -> Result<Self::R, Self::E> {
|
||||
use awc::Client;
|
||||
use urlencoding::encode;
|
||||
use http_core::HasHttp;
|
||||
|
||||
// collect lnk_p_* query params
|
||||
let mut query_params: Vec<(String, String)> = Vec::new();
|
||||
#(#query_param_code)*
|
||||
|
||||
let mut url = base_url.to_string();
|
||||
// pick URL
|
||||
let mut url = if let Some(u) = override_url {
|
||||
u.to_string()
|
||||
} else if sandbox {
|
||||
<Self as HasHttp>::sandbox_url().to_string()
|
||||
} else {
|
||||
<Self as HasHttp>::live_url().to_string()
|
||||
};
|
||||
|
||||
if !query_params.is_empty() {
|
||||
let mut query_string = String::new();
|
||||
let mut first = true;
|
||||
for (k, v) in &query_params {
|
||||
if !first {
|
||||
query_string.push('&');
|
||||
}
|
||||
first = false;
|
||||
query_string.push_str(&format!("{}={}", k, encode(v)));
|
||||
}
|
||||
let query_string = query_params.into_iter()
|
||||
.map(|(k,v)| format!("{}={}", k, encode(&v)))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
url.push('?');
|
||||
url.push_str(&query_string);
|
||||
}
|
||||
|
||||
// choose method
|
||||
let client = Client::default();
|
||||
let mut request = client.get(url);
|
||||
let mut request = match method_override.unwrap_or("GET") {
|
||||
"GET" => client.get(url),
|
||||
"POST" => client.post(url),
|
||||
"PUT" => client.put(url),
|
||||
"DELETE" => client.delete(url),
|
||||
"PATCH" => client.patch(url),
|
||||
m => panic!("Unsupported method override: {}", m),
|
||||
};
|
||||
|
||||
// add headers
|
||||
if let Some(hdrs) = headers {
|
||||
for (k, v) in hdrs {
|
||||
request = request.append_header((k, v));
|
||||
@ -120,7 +130,6 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream {
|
||||
TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let input_enum = parse_macro_input!(item as ItemEnum);
|
||||
@ -159,31 +168,19 @@ pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let client = Arc::new(awc::Client::default());
|
||||
let keys = Arc::new(crate::load_api_keys()?);
|
||||
|
||||
const THREADS: usize = 4;
|
||||
let total = queries.len();
|
||||
let per_thread = std::cmp::max(1, total / THREADS);
|
||||
let shared_queries = Arc::new(queries);
|
||||
|
||||
// Spawn all queries as async tasks
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..THREADS {
|
||||
let queries = Arc::clone(&shared_queries);
|
||||
for q in queries {
|
||||
let client = Arc::clone(&client);
|
||||
let keys = Arc::clone(&keys);
|
||||
let start = i * per_thread;
|
||||
let end = if i == THREADS - 1 { total } else { start + per_thread };
|
||||
|
||||
let handle = std::thread::spawn(move || {
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
for q in &queries[start..end] {
|
||||
rt.block_on(q.send_all(&client, &keys)).unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
handles.push(tokio::spawn(async move {
|
||||
q.send_all(&client, &keys).await
|
||||
}));
|
||||
}
|
||||
|
||||
// Await all results and propagate first error (if any)
|
||||
for h in handles {
|
||||
h.join().expect("Thread panicked");
|
||||
h.await??;
|
||||
}
|
||||
}
|
||||
other => {
|
||||
@ -212,71 +209,6 @@ pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Trait for dispatching API calls
|
||||
pub trait ApiDispatch {
|
||||
fn send_all(
|
||||
&self,
|
||||
client: &awc::Client,
|
||||
keys: &std::collections::HashMap<String, crate::Keys>,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send>>;
|
||||
}
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn api_dispatch(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(item as syn::ItemEnum);
|
||||
let enum_ident = &input.ident;
|
||||
|
||||
// Parse attribute input: input = "MyQuery"
|
||||
let meta_args = attr.to_string();
|
||||
let input_type: syn::Ident = {
|
||||
let cleaned = meta_args.trim().replace("input", "").replace('=', "").replace('"', "").trim().to_string();
|
||||
syn::Ident::new(&cleaned, proc_macro2::Span::call_site())
|
||||
};
|
||||
|
||||
let expanded = quote! {
|
||||
#input
|
||||
|
||||
impl ApiDispatch for #enum_ident {
|
||||
fn send_all(
|
||||
&self,
|
||||
client: &awc::Client,
|
||||
keys: &std::collections::HashMap<String, crate::Keys>,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send>> {
|
||||
Box::pin(async move {
|
||||
match self {
|
||||
#enum_ident::Single { query } => {
|
||||
let parsed: #input_type = serde_json::from_str(query)?;
|
||||
parsed.send(client, keys).await?;
|
||||
}
|
||||
#enum_ident::Bulk { input } => {
|
||||
let json = if let Some(raw) = input {
|
||||
if std::path::Path::new(&raw).exists() {
|
||||
std::fs::read_to_string(raw)?
|
||||
} else {
|
||||
raw.clone()
|
||||
}
|
||||
} else {
|
||||
use std::io::Read;
|
||||
let mut buf = String::new();
|
||||
std::io::stdin().read_to_string(&mut buf)?;
|
||||
buf
|
||||
};
|
||||
|
||||
let items: Vec<#input_type> = serde_json::from_str(&json)?;
|
||||
for item in items {
|
||||
item.send(client, keys).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
|
||||
Reference in New Issue
Block a user