hopefully the dispatch macro is properly multithreaded now

This commit is contained in:
2025-10-10 11:36:04 -04:00
parent cb25899252
commit ba2cde601d
3 changed files with 1087 additions and 644 deletions

1398
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -12,3 +12,4 @@ proc-macro2 = "1.0.95"
quote = "1.0.40"
syn = { version = "2.0.104", features = ["full"] }
http_core = { path = "../http_core" }
anyhow = "1.0.100"

View File

@ -1,22 +1,22 @@
// src/lib.rs -- proc-macro crate
use async_trait::async_trait;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{
parse_macro_input, ItemStruct, ItemEnum, Fields, Type, Meta, Lit, Expr,
punctuated::Punctuated, token::Comma, MetaNameValue,
// Important: Parser trait brings `parse2` into scope for parser functions
parse::Parser,
parse_macro_input, Fields, ItemStruct, Lit, Meta, MetaNameValue, Expr, Type,
punctuated::Punctuated, token::Comma, parse::Parser,
};
/// ---------------------------------------------------------------------------
/// #[http(...)] attribute macro
/// Accepts (with defaults):
/// Generates HasHttp + Queryable impls (reqwest-based).
///
/// Defaults:
/// - method = "GET"
/// - live_url = ""
/// - sandbox_url = ""
/// - response = "<StructName>Resp"
/// - error = "Box<dyn std::error::Error>"
/// ---------------------------------------------------------------------------
#[proc_macro_attribute]
pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
@ -27,33 +27,20 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut live_url_s = "".to_string();
let mut sandbox_url_s = "".to_string();
let mut response_s = format!("{}Resp", struct_ident);
let mut error_s = "Box<dyn std::error::Error + Send + Sync>".to_string();
// turn attr into TokenStream2
// parse #[http(...)] arguments
let attr_ts: proc_macro2::TokenStream = attr.into();
if !attr_ts.is_empty() {
if let Ok(meta) = syn::parse2::<Meta>(attr_ts.clone()) {
if let Meta::List(list) = meta {
let nested: Punctuated<MetaNameValue, Comma> =
Punctuated::parse_terminated.parse2(list.tokens)
Punctuated::parse_terminated
.parse2(list.tokens)
.expect("failed to parse http attribute list");
for nv in nested {
if let Some(ident) = nv.path.get_ident() {
let key = ident.to_string();
if key == "error" {
match &nv.value {
Expr::Lit(expr_lit) => {
if let Lit::Str(ls) = &expr_lit.lit {
error_s = ls.value();
}
}
Expr::Path(p) => {
error_s = quote!(#p).to_string();
}
_ => {}
}
} else if let Expr::Lit(expr_lit) = &nv.value {
if let Expr::Lit(expr_lit) = &nv.value {
if let Lit::Str(ls) = &expr_lit.lit {
let val = ls.value();
match key.as_str() {
@ -71,7 +58,7 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
// collect lnk_p_* -> query
// collect query param fields (lnk_p_*)
let mut qparam_snippets = Vec::new();
if let Fields::Named(fields_named) = &input.fields {
for field in &fields_named.named {
@ -88,16 +75,15 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
// parse response & error types
// response type (default = serde_json::Value)
let response_ty: Type = syn::parse_str(&response_s)
.unwrap_or_else(|_| syn::parse_str("serde_json::Value").unwrap());
let error_ty: Type = syn::parse_str(&error_s)
.unwrap_or_else(|_| syn::parse_str("Box<dyn std::error::Error + Send + Sync>").unwrap());
let method_lit = syn::LitStr::new(&method_s, Span::call_site());
let live_lit = syn::LitStr::new(&live_url_s, Span::call_site());
let sandbox_lit = syn::LitStr::new(&sandbox_url_s, Span::call_site());
// ✅ Final merged + polished implementation
let expanded = quote! {
#input
@ -107,60 +93,74 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
fn sandbox_url() -> &'static str { #sandbox_lit }
}
#[async_trait::async_trait]
#[async_trait]
impl http_core::Queryable for #struct_ident {
type R = #response_ty;
type E = #error_ty;
type Response = #response_ty;
async fn send(
&self,
client: std::sync::Arc<reqwest::Client>,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
headers: Option<Vec<(&str, &str)>>,
) -> Result<Self::R, Self::E> {
use awc::Client;
use http_core::HasHttp;
use urlencoding::encode;
) -> anyhow::Result<Self::Response> {
use http_core::{HasHttp, replace_path_params, append_query_params};
use anyhow::{Context, anyhow};
use serde_json::to_vec;
let mut query_params: Vec<(String, String)> = Vec::new();
#(#qparam_snippets)*
let mut url = if let Some(u) = override_url {
u.to_string()
} else if sandbox {
<Self as HasHttp>::sandbox_url().to_string()
} else {
<Self as HasHttp>::live_url().to_string()
};
// base URL resolution
let mut url = override_url.unwrap_or_else(|| {
if sandbox {
<Self as HasHttp>::sandbox_url()
} else {
<Self as HasHttp>::live_url()
}
}).to_string();
if !query_params.is_empty() {
let qs = query_params.into_iter()
.map(|(k,v)| format!("{}={}", k, encode(&v)))
.collect::<Vec<_>>()
.join("&");
url.push('?');
url.push_str(&qs);
}
// replace path params if any exist in the template
url = replace_path_params(&url, &[], self)
.context("Failed to replace path params")?;
let method = method_override.unwrap_or(#method_lit);
let client = Client::default();
let mut req = match method {
"GET" => client.get(url.clone()),
"POST" => client.post(url.clone()),
"PUT" => client.put(url.clone()),
"DELETE" => client.delete(url.clone()),
"PATCH" => client.patch(url.clone()),
_ => client.get(url.clone()),
};
// append query params if present
url = append_query_params(&url, &query_params.iter().map(|(k, _)| k.clone()).collect::<Vec<_>>(), self)
.context("Failed to append query params")?;
let method_str = method_override.unwrap_or(#method_lit);
let method = reqwest::Method::from_bytes(method_str.as_bytes())
.unwrap_or(reqwest::Method::GET);
// build request
let mut req = client.request(method.clone(), &url);
if let Some(hs) = headers {
for (k, v) in hs { req = req.append_header((k, v)); }
for (k, v) in hs {
req = req.header(k, v);
}
}
let resp = req.send().await.map_err(Into::into)?;
let body = resp.body().await.map_err(Into::into)?;
let parsed: Self::R = serde_json::from_slice(&body).map_err(Into::into)?;
// send body if applicable
let resp = if matches!(method, reqwest::Method::POST | reqwest::Method::PUT | reqwest::Method::PATCH) {
let body = to_vec(self).context("Failed to serialize request body")?;
req = req.body(body);
req.send().await.context("HTTP request failed")?
} else {
req.send().await.context("HTTP request failed")?
};
// handle response
let status = resp.status();
let bytes = resp.bytes().await.context("Failed to read response body")?;
if !status.is_success() {
return Err(anyhow!("HTTP {}: {}", status, String::from_utf8_lossy(&bytes)));
}
let parsed: Self::Response = serde_json::from_slice(&bytes)
.context("Failed to deserialize response")?;
Ok(parsed)
}
}
@ -169,86 +169,136 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
TokenStream::from(expanded)
}
/// CLI generator macro (keeps your existing CLI structure)
#[proc_macro_attribute]
pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input_enum = parse_macro_input!(item as ItemEnum);
let top_enum_ident = &input_enum.ident;
let top_variants = &input_enum.variants;
#[proc_macro]
pub fn dispatch(input: TokenStream) -> TokenStream {
use quote::format_ident;
let input = parse_macro_input!(input as syn::ExprTuple);
// Build outer match arms same as your previous implementation
let match_arms: Vec<_> = top_variants.iter().map(|variant| {
let variant_ident = &variant.ident;
if input.elems.len() != 2 {
return syn::Error::new_spanned(input, "Expected (TopEnumIdent, hashmap!{ ... })")
.to_compile_error()
.into();
}
// Expect tuple variant like Alpaca(AlpacaCmd)
let inner_type = match &variant.fields {
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
match &fields.unnamed.first().unwrap().ty {
syn::Type::Path(p) => p.path.segments.last().unwrap().ident.clone(),
_ => panic!("Expected tuple variant with a type path"),
}
}
_ => panic!("Each variant must be a tuple variant like `Alpaca(AlpacaCmd)`"),
};
quote! {
#top_enum_ident::#variant_ident(inner) => {
match inner {
#inner_type::Bulk { input } => {
let mut reader: Box<dyn std::io::Read> = match input {
Some(path) => Box::new(std::fs::File::open(path)?),
None => Box::new(std::io::stdin()),
};
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
let queries: Vec<#inner_type> = serde_json::from_str(&buf)?;
use std::sync::Arc;
let client = Arc::new(awc::Client::default());
let keys = Arc::new(crate::load_api_keys()?);
// Spawn tasks for each query and await
let mut handles = Vec::new();
for q in queries {
let client = Arc::clone(&client);
let keys = Arc::clone(&keys);
handles.push(tokio::spawn(async move {
q.send_all(&client, &keys, None, false, None).await
}));
}
for h in handles {
h.await??;
}
}
other => {
let client = awc::Client::default();
let keys = crate::load_api_keys()?;
other.send_all(&client, &keys, None, false, None).await?;
}
}
}
}
}).collect();
let expanded = quote! {
use clap::Parser;
use std::io::Read;
#[derive(clap::Parser, Debug)]
#input_enum
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cmd = #top_enum_ident::parse();
match cmd {
#(#match_arms),*
}
Ok(())
// Extract top enum name
let top_enum_ident = match &input.elems[0] {
syn::Expr::Path(p) => &p.path.segments.last().unwrap().ident,
other => {
return syn::Error::new_spanned(other, "First element must be a type path").to_compile_error().into();
}
};
TokenStream::from(expanded)
// Parse second argument (the hashmap! macro)
let map_expr = &input.elems[1];
let map_macro = match map_expr {
syn::Expr::Macro(m) if m.mac.path.is_ident("hashmap") => m,
other => {
return syn::Error::new_spanned(other, "Second element must be a hashmap! macro").to_compile_error().into();
}
};
let tokens: proc_macro2::TokenStream = map_macro.mac.tokens.clone();
let parser = syn::parse2::<syn::ExprArray>(tokens);
let pairs: Vec<(syn::Ident, Vec<(syn::Ident, syn::Path)>)> = if let Ok(array) = parser {
array.elems.into_iter().filter_map(|expr| {
if let syn::Expr::Tuple(tuple) = expr {
if tuple.elems.len() == 2 {
if let (syn::Expr::Path(api_ident), syn::Expr::Array(cmds)) =
(&tuple.elems[0], &tuple.elems[1])
{
let api_ident = api_ident.path.segments.last().unwrap().ident.clone();
let mut subcmds = Vec::new();
for cmd_expr in &cmds.elems {
if let syn::Expr::Tuple(cmd_pair) = cmd_expr {
if cmd_pair.elems.len() == 2 {
if let (syn::Expr::Path(cmd_ident), syn::Expr::Path(path)) =
(&cmd_pair.elems[0], &cmd_pair.elems[1])
{
subcmds.push((
cmd_ident.path.segments.last().unwrap().ident.clone(),
path.path.clone(),
));
}
}
}
}
return Some((api_ident, subcmds));
}
}
}
None
}).collect()
} else {
Vec::new()
};
// Generate async match arms for each API
let mut match_arms = Vec::new();
for (api_ident, subcmds) in pairs {
for (cmd_ident, path) in subcmds {
match_arms.push(quote! {
#top_enum_ident::#api_ident(inner) => {
match inner {
#path::#cmd_ident(reqs) => {
use futures::stream::{FuturesUnordered, StreamExt};
use std::sync::Arc;
use tokio::sync::Semaphore;
// Limit concurrency to 10 by default
let semaphore = Arc::new(Semaphore::new(10));
let mut tasks = FuturesUnordered::new();
for req in reqs {
let client = client.clone();
let permit = semaphore.clone().acquire_owned().await.unwrap();
let override_url = override_url.cloned();
let method_override = method_override.cloned();
let keys = keys.clone();
let sandbox = sandbox;
tasks.push(tokio::spawn(async move {
let _permit = permit;
req.send(client, override_url.as_deref(), sandbox, method_override.as_deref(), None)
.await
.map(|r| serde_json::to_value(r).unwrap_or(serde_json::Value::Null))
}));
}
let mut results = Vec::new();
while let Some(res) = tasks.next().await {
results.push(res.unwrap_or_else(|e| Err(anyhow::anyhow!("JoinError: {}", e))));
}
Ok(results)
}
}
}
});
}
}
let expanded = quote! {
#[async_trait]
impl http_core::ApiDispatch for #top_enum_ident {
type Request = ();
type Response = ();
type Error = anyhow::Error;
async fn send_all(
&self,
client: std::sync::Arc<reqwest::Client>,
keys: std::sync::Arc<http_core::Keys>,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
) -> anyhow::Result<Vec<anyhow::Result<serde_json::Value>>> {
match self {
#(#match_arms),*
}
}
}
};
expanded.into()
}