hopefully the dispatch macro is properly multithreaded now
This commit is contained in:
1398
Cargo.lock
generated
1398
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -12,3 +12,4 @@ proc-macro2 = "1.0.95"
|
||||
quote = "1.0.40"
|
||||
syn = { version = "2.0.104", features = ["full"] }
|
||||
http_core = { path = "../http_core" }
|
||||
anyhow = "1.0.100"
|
||||
|
||||
332
src/lib.rs
332
src/lib.rs
@ -1,22 +1,22 @@
|
||||
// src/lib.rs -- proc-macro crate
|
||||
|
||||
use async_trait::async_trait;
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::Span;
|
||||
use quote::quote;
|
||||
use syn::{
|
||||
parse_macro_input, ItemStruct, ItemEnum, Fields, Type, Meta, Lit, Expr,
|
||||
punctuated::Punctuated, token::Comma, MetaNameValue,
|
||||
// Important: Parser trait brings `parse2` into scope for parser functions
|
||||
parse::Parser,
|
||||
parse_macro_input, Fields, ItemStruct, Lit, Meta, MetaNameValue, Expr, Type,
|
||||
punctuated::Punctuated, token::Comma, parse::Parser,
|
||||
};
|
||||
|
||||
/// ---------------------------------------------------------------------------
|
||||
/// #[http(...)] attribute macro
|
||||
/// Accepts (with defaults):
|
||||
/// Generates HasHttp + Queryable impls (reqwest-based).
|
||||
///
|
||||
/// Defaults:
|
||||
/// - method = "GET"
|
||||
/// - live_url = ""
|
||||
/// - sandbox_url = ""
|
||||
/// - response = "<StructName>Resp"
|
||||
/// - error = "Box<dyn std::error::Error>"
|
||||
/// ---------------------------------------------------------------------------
|
||||
#[proc_macro_attribute]
|
||||
pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(item as ItemStruct);
|
||||
@ -27,33 +27,20 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let mut live_url_s = "".to_string();
|
||||
let mut sandbox_url_s = "".to_string();
|
||||
let mut response_s = format!("{}Resp", struct_ident);
|
||||
let mut error_s = "Box<dyn std::error::Error + Send + Sync>".to_string();
|
||||
|
||||
// turn attr into TokenStream2
|
||||
// parse #[http(...)] arguments
|
||||
let attr_ts: proc_macro2::TokenStream = attr.into();
|
||||
|
||||
if !attr_ts.is_empty() {
|
||||
if let Ok(meta) = syn::parse2::<Meta>(attr_ts.clone()) {
|
||||
if let Meta::List(list) = meta {
|
||||
let nested: Punctuated<MetaNameValue, Comma> =
|
||||
Punctuated::parse_terminated.parse2(list.tokens)
|
||||
Punctuated::parse_terminated
|
||||
.parse2(list.tokens)
|
||||
.expect("failed to parse http attribute list");
|
||||
for nv in nested {
|
||||
if let Some(ident) = nv.path.get_ident() {
|
||||
let key = ident.to_string();
|
||||
if key == "error" {
|
||||
match &nv.value {
|
||||
Expr::Lit(expr_lit) => {
|
||||
if let Lit::Str(ls) = &expr_lit.lit {
|
||||
error_s = ls.value();
|
||||
}
|
||||
}
|
||||
Expr::Path(p) => {
|
||||
error_s = quote!(#p).to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else if let Expr::Lit(expr_lit) = &nv.value {
|
||||
if let Expr::Lit(expr_lit) = &nv.value {
|
||||
if let Lit::Str(ls) = &expr_lit.lit {
|
||||
let val = ls.value();
|
||||
match key.as_str() {
|
||||
@ -71,7 +58,7 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
}
|
||||
}
|
||||
|
||||
// collect lnk_p_* -> query
|
||||
// collect query param fields (lnk_p_*)
|
||||
let mut qparam_snippets = Vec::new();
|
||||
if let Fields::Named(fields_named) = &input.fields {
|
||||
for field in &fields_named.named {
|
||||
@ -88,16 +75,15 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
}
|
||||
}
|
||||
|
||||
// parse response & error types
|
||||
// response type (default = serde_json::Value)
|
||||
let response_ty: Type = syn::parse_str(&response_s)
|
||||
.unwrap_or_else(|_| syn::parse_str("serde_json::Value").unwrap());
|
||||
let error_ty: Type = syn::parse_str(&error_s)
|
||||
.unwrap_or_else(|_| syn::parse_str("Box<dyn std::error::Error + Send + Sync>").unwrap());
|
||||
|
||||
let method_lit = syn::LitStr::new(&method_s, Span::call_site());
|
||||
let live_lit = syn::LitStr::new(&live_url_s, Span::call_site());
|
||||
let sandbox_lit = syn::LitStr::new(&sandbox_url_s, Span::call_site());
|
||||
|
||||
// ✅ Final merged + polished implementation
|
||||
let expanded = quote! {
|
||||
#input
|
||||
|
||||
@ -107,60 +93,74 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
fn sandbox_url() -> &'static str { #sandbox_lit }
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[async_trait]
|
||||
impl http_core::Queryable for #struct_ident {
|
||||
type R = #response_ty;
|
||||
type E = #error_ty;
|
||||
type Response = #response_ty;
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
client: std::sync::Arc<reqwest::Client>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
headers: Option<Vec<(&str, &str)>>,
|
||||
) -> Result<Self::R, Self::E> {
|
||||
use awc::Client;
|
||||
use http_core::HasHttp;
|
||||
use urlencoding::encode;
|
||||
) -> anyhow::Result<Self::Response> {
|
||||
use http_core::{HasHttp, replace_path_params, append_query_params};
|
||||
use anyhow::{Context, anyhow};
|
||||
use serde_json::to_vec;
|
||||
|
||||
let mut query_params: Vec<(String, String)> = Vec::new();
|
||||
#(#qparam_snippets)*
|
||||
|
||||
let mut url = if let Some(u) = override_url {
|
||||
u.to_string()
|
||||
} else if sandbox {
|
||||
<Self as HasHttp>::sandbox_url().to_string()
|
||||
} else {
|
||||
<Self as HasHttp>::live_url().to_string()
|
||||
};
|
||||
// base URL resolution
|
||||
let mut url = override_url.unwrap_or_else(|| {
|
||||
if sandbox {
|
||||
<Self as HasHttp>::sandbox_url()
|
||||
} else {
|
||||
<Self as HasHttp>::live_url()
|
||||
}
|
||||
}).to_string();
|
||||
|
||||
if !query_params.is_empty() {
|
||||
let qs = query_params.into_iter()
|
||||
.map(|(k,v)| format!("{}={}", k, encode(&v)))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
url.push('?');
|
||||
url.push_str(&qs);
|
||||
}
|
||||
// replace path params if any exist in the template
|
||||
url = replace_path_params(&url, &[], self)
|
||||
.context("Failed to replace path params")?;
|
||||
|
||||
let method = method_override.unwrap_or(#method_lit);
|
||||
let client = Client::default();
|
||||
let mut req = match method {
|
||||
"GET" => client.get(url.clone()),
|
||||
"POST" => client.post(url.clone()),
|
||||
"PUT" => client.put(url.clone()),
|
||||
"DELETE" => client.delete(url.clone()),
|
||||
"PATCH" => client.patch(url.clone()),
|
||||
_ => client.get(url.clone()),
|
||||
};
|
||||
// append query params if present
|
||||
url = append_query_params(&url, &query_params.iter().map(|(k, _)| k.clone()).collect::<Vec<_>>(), self)
|
||||
.context("Failed to append query params")?;
|
||||
|
||||
let method_str = method_override.unwrap_or(#method_lit);
|
||||
let method = reqwest::Method::from_bytes(method_str.as_bytes())
|
||||
.unwrap_or(reqwest::Method::GET);
|
||||
|
||||
// build request
|
||||
let mut req = client.request(method.clone(), &url);
|
||||
|
||||
if let Some(hs) = headers {
|
||||
for (k, v) in hs { req = req.append_header((k, v)); }
|
||||
for (k, v) in hs {
|
||||
req = req.header(k, v);
|
||||
}
|
||||
}
|
||||
|
||||
let resp = req.send().await.map_err(Into::into)?;
|
||||
let body = resp.body().await.map_err(Into::into)?;
|
||||
let parsed: Self::R = serde_json::from_slice(&body).map_err(Into::into)?;
|
||||
// send body if applicable
|
||||
let resp = if matches!(method, reqwest::Method::POST | reqwest::Method::PUT | reqwest::Method::PATCH) {
|
||||
let body = to_vec(self).context("Failed to serialize request body")?;
|
||||
req = req.body(body);
|
||||
req.send().await.context("HTTP request failed")?
|
||||
} else {
|
||||
req.send().await.context("HTTP request failed")?
|
||||
};
|
||||
|
||||
// handle response
|
||||
let status = resp.status();
|
||||
let bytes = resp.bytes().await.context("Failed to read response body")?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(anyhow!("HTTP {}: {}", status, String::from_utf8_lossy(&bytes)));
|
||||
}
|
||||
|
||||
let parsed: Self::Response = serde_json::from_slice(&bytes)
|
||||
.context("Failed to deserialize response")?;
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
@ -169,86 +169,136 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
TokenStream::from(expanded)
|
||||
}
|
||||
|
||||
/// CLI generator macro (keeps your existing CLI structure)
|
||||
#[proc_macro_attribute]
|
||||
pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let input_enum = parse_macro_input!(item as ItemEnum);
|
||||
let top_enum_ident = &input_enum.ident;
|
||||
let top_variants = &input_enum.variants;
|
||||
#[proc_macro]
|
||||
pub fn dispatch(input: TokenStream) -> TokenStream {
|
||||
use quote::format_ident;
|
||||
let input = parse_macro_input!(input as syn::ExprTuple);
|
||||
|
||||
// Build outer match arms same as your previous implementation
|
||||
let match_arms: Vec<_> = top_variants.iter().map(|variant| {
|
||||
let variant_ident = &variant.ident;
|
||||
if input.elems.len() != 2 {
|
||||
return syn::Error::new_spanned(input, "Expected (TopEnumIdent, hashmap!{ ... })")
|
||||
.to_compile_error()
|
||||
.into();
|
||||
}
|
||||
|
||||
// Expect tuple variant like Alpaca(AlpacaCmd)
|
||||
let inner_type = match &variant.fields {
|
||||
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
||||
match &fields.unnamed.first().unwrap().ty {
|
||||
syn::Type::Path(p) => p.path.segments.last().unwrap().ident.clone(),
|
||||
_ => panic!("Expected tuple variant with a type path"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Each variant must be a tuple variant like `Alpaca(AlpacaCmd)`"),
|
||||
};
|
||||
|
||||
quote! {
|
||||
#top_enum_ident::#variant_ident(inner) => {
|
||||
match inner {
|
||||
#inner_type::Bulk { input } => {
|
||||
let mut reader: Box<dyn std::io::Read> = match input {
|
||||
Some(path) => Box::new(std::fs::File::open(path)?),
|
||||
None => Box::new(std::io::stdin()),
|
||||
};
|
||||
|
||||
let mut buf = String::new();
|
||||
reader.read_to_string(&mut buf)?;
|
||||
let queries: Vec<#inner_type> = serde_json::from_str(&buf)?;
|
||||
|
||||
use std::sync::Arc;
|
||||
let client = Arc::new(awc::Client::default());
|
||||
let keys = Arc::new(crate::load_api_keys()?);
|
||||
|
||||
// Spawn tasks for each query and await
|
||||
let mut handles = Vec::new();
|
||||
for q in queries {
|
||||
let client = Arc::clone(&client);
|
||||
let keys = Arc::clone(&keys);
|
||||
handles.push(tokio::spawn(async move {
|
||||
q.send_all(&client, &keys, None, false, None).await
|
||||
}));
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.await??;
|
||||
}
|
||||
}
|
||||
other => {
|
||||
let client = awc::Client::default();
|
||||
let keys = crate::load_api_keys()?;
|
||||
other.send_all(&client, &keys, None, false, None).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}).collect();
|
||||
|
||||
let expanded = quote! {
|
||||
use clap::Parser;
|
||||
use std::io::Read;
|
||||
|
||||
#[derive(clap::Parser, Debug)]
|
||||
#input_enum
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let cmd = #top_enum_ident::parse();
|
||||
match cmd {
|
||||
#(#match_arms),*
|
||||
}
|
||||
Ok(())
|
||||
// Extract top enum name
|
||||
let top_enum_ident = match &input.elems[0] {
|
||||
syn::Expr::Path(p) => &p.path.segments.last().unwrap().ident,
|
||||
other => {
|
||||
return syn::Error::new_spanned(other, "First element must be a type path").to_compile_error().into();
|
||||
}
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
// Parse second argument (the hashmap! macro)
|
||||
let map_expr = &input.elems[1];
|
||||
let map_macro = match map_expr {
|
||||
syn::Expr::Macro(m) if m.mac.path.is_ident("hashmap") => m,
|
||||
other => {
|
||||
return syn::Error::new_spanned(other, "Second element must be a hashmap! macro").to_compile_error().into();
|
||||
}
|
||||
};
|
||||
|
||||
let tokens: proc_macro2::TokenStream = map_macro.mac.tokens.clone();
|
||||
let parser = syn::parse2::<syn::ExprArray>(tokens);
|
||||
|
||||
let pairs: Vec<(syn::Ident, Vec<(syn::Ident, syn::Path)>)> = if let Ok(array) = parser {
|
||||
array.elems.into_iter().filter_map(|expr| {
|
||||
if let syn::Expr::Tuple(tuple) = expr {
|
||||
if tuple.elems.len() == 2 {
|
||||
if let (syn::Expr::Path(api_ident), syn::Expr::Array(cmds)) =
|
||||
(&tuple.elems[0], &tuple.elems[1])
|
||||
{
|
||||
let api_ident = api_ident.path.segments.last().unwrap().ident.clone();
|
||||
let mut subcmds = Vec::new();
|
||||
for cmd_expr in &cmds.elems {
|
||||
if let syn::Expr::Tuple(cmd_pair) = cmd_expr {
|
||||
if cmd_pair.elems.len() == 2 {
|
||||
if let (syn::Expr::Path(cmd_ident), syn::Expr::Path(path)) =
|
||||
(&cmd_pair.elems[0], &cmd_pair.elems[1])
|
||||
{
|
||||
subcmds.push((
|
||||
cmd_ident.path.segments.last().unwrap().ident.clone(),
|
||||
path.path.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Some((api_ident, subcmds));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}).collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// Generate async match arms for each API
|
||||
let mut match_arms = Vec::new();
|
||||
for (api_ident, subcmds) in pairs {
|
||||
for (cmd_ident, path) in subcmds {
|
||||
match_arms.push(quote! {
|
||||
#top_enum_ident::#api_ident(inner) => {
|
||||
match inner {
|
||||
#path::#cmd_ident(reqs) => {
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
// Limit concurrency to 10 by default
|
||||
let semaphore = Arc::new(Semaphore::new(10));
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
for req in reqs {
|
||||
let client = client.clone();
|
||||
let permit = semaphore.clone().acquire_owned().await.unwrap();
|
||||
let override_url = override_url.cloned();
|
||||
let method_override = method_override.cloned();
|
||||
let keys = keys.clone();
|
||||
let sandbox = sandbox;
|
||||
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let _permit = permit;
|
||||
req.send(client, override_url.as_deref(), sandbox, method_override.as_deref(), None)
|
||||
.await
|
||||
.map(|r| serde_json::to_value(r).unwrap_or(serde_json::Value::Null))
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(res) = tasks.next().await {
|
||||
results.push(res.unwrap_or_else(|e| Err(anyhow::anyhow!("JoinError: {}", e))));
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let expanded = quote! {
|
||||
#[async_trait]
|
||||
impl http_core::ApiDispatch for #top_enum_ident {
|
||||
type Request = ();
|
||||
type Response = ();
|
||||
type Error = anyhow::Error;
|
||||
|
||||
async fn send_all(
|
||||
&self,
|
||||
client: std::sync::Arc<reqwest::Client>,
|
||||
keys: std::sync::Arc<http_core::Keys>,
|
||||
override_url: Option<&str>,
|
||||
sandbox: bool,
|
||||
method_override: Option<&str>,
|
||||
) -> anyhow::Result<Vec<anyhow::Result<serde_json::Value>>> {
|
||||
match self {
|
||||
#(#match_arms),*
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
expanded.into()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user