hopefully the dispatch macro is properly multithreaded now
This commit is contained in:
1398
Cargo.lock
generated
1398
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -12,3 +12,4 @@ proc-macro2 = "1.0.95"
|
|||||||
quote = "1.0.40"
|
quote = "1.0.40"
|
||||||
syn = { version = "2.0.104", features = ["full"] }
|
syn = { version = "2.0.104", features = ["full"] }
|
||||||
http_core = { path = "../http_core" }
|
http_core = { path = "../http_core" }
|
||||||
|
anyhow = "1.0.100"
|
||||||
|
|||||||
332
src/lib.rs
332
src/lib.rs
@ -1,22 +1,22 @@
|
|||||||
// src/lib.rs -- proc-macro crate
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use proc_macro2::Span;
|
use proc_macro2::Span;
|
||||||
use quote::quote;
|
use quote::quote;
|
||||||
use syn::{
|
use syn::{
|
||||||
parse_macro_input, ItemStruct, ItemEnum, Fields, Type, Meta, Lit, Expr,
|
parse_macro_input, Fields, ItemStruct, Lit, Meta, MetaNameValue, Expr, Type,
|
||||||
punctuated::Punctuated, token::Comma, MetaNameValue,
|
punctuated::Punctuated, token::Comma, parse::Parser,
|
||||||
// Important: Parser trait brings `parse2` into scope for parser functions
|
|
||||||
parse::Parser,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// ---------------------------------------------------------------------------
|
||||||
/// #[http(...)] attribute macro
|
/// #[http(...)] attribute macro
|
||||||
/// Accepts (with defaults):
|
/// Generates HasHttp + Queryable impls (reqwest-based).
|
||||||
|
///
|
||||||
|
/// Defaults:
|
||||||
/// - method = "GET"
|
/// - method = "GET"
|
||||||
/// - live_url = ""
|
/// - live_url = ""
|
||||||
/// - sandbox_url = ""
|
/// - sandbox_url = ""
|
||||||
/// - response = "<StructName>Resp"
|
/// - response = "<StructName>Resp"
|
||||||
/// - error = "Box<dyn std::error::Error>"
|
/// ---------------------------------------------------------------------------
|
||||||
#[proc_macro_attribute]
|
#[proc_macro_attribute]
|
||||||
pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
let input = parse_macro_input!(item as ItemStruct);
|
let input = parse_macro_input!(item as ItemStruct);
|
||||||
@ -27,33 +27,20 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
let mut live_url_s = "".to_string();
|
let mut live_url_s = "".to_string();
|
||||||
let mut sandbox_url_s = "".to_string();
|
let mut sandbox_url_s = "".to_string();
|
||||||
let mut response_s = format!("{}Resp", struct_ident);
|
let mut response_s = format!("{}Resp", struct_ident);
|
||||||
let mut error_s = "Box<dyn std::error::Error + Send + Sync>".to_string();
|
|
||||||
|
|
||||||
// turn attr into TokenStream2
|
// parse #[http(...)] arguments
|
||||||
let attr_ts: proc_macro2::TokenStream = attr.into();
|
let attr_ts: proc_macro2::TokenStream = attr.into();
|
||||||
|
|
||||||
if !attr_ts.is_empty() {
|
if !attr_ts.is_empty() {
|
||||||
if let Ok(meta) = syn::parse2::<Meta>(attr_ts.clone()) {
|
if let Ok(meta) = syn::parse2::<Meta>(attr_ts.clone()) {
|
||||||
if let Meta::List(list) = meta {
|
if let Meta::List(list) = meta {
|
||||||
let nested: Punctuated<MetaNameValue, Comma> =
|
let nested: Punctuated<MetaNameValue, Comma> =
|
||||||
Punctuated::parse_terminated.parse2(list.tokens)
|
Punctuated::parse_terminated
|
||||||
|
.parse2(list.tokens)
|
||||||
.expect("failed to parse http attribute list");
|
.expect("failed to parse http attribute list");
|
||||||
for nv in nested {
|
for nv in nested {
|
||||||
if let Some(ident) = nv.path.get_ident() {
|
if let Some(ident) = nv.path.get_ident() {
|
||||||
let key = ident.to_string();
|
let key = ident.to_string();
|
||||||
if key == "error" {
|
if let Expr::Lit(expr_lit) = &nv.value {
|
||||||
match &nv.value {
|
|
||||||
Expr::Lit(expr_lit) => {
|
|
||||||
if let Lit::Str(ls) = &expr_lit.lit {
|
|
||||||
error_s = ls.value();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Expr::Path(p) => {
|
|
||||||
error_s = quote!(#p).to_string();
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
} else if let Expr::Lit(expr_lit) = &nv.value {
|
|
||||||
if let Lit::Str(ls) = &expr_lit.lit {
|
if let Lit::Str(ls) = &expr_lit.lit {
|
||||||
let val = ls.value();
|
let val = ls.value();
|
||||||
match key.as_str() {
|
match key.as_str() {
|
||||||
@ -71,7 +58,7 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// collect lnk_p_* -> query
|
// collect query param fields (lnk_p_*)
|
||||||
let mut qparam_snippets = Vec::new();
|
let mut qparam_snippets = Vec::new();
|
||||||
if let Fields::Named(fields_named) = &input.fields {
|
if let Fields::Named(fields_named) = &input.fields {
|
||||||
for field in &fields_named.named {
|
for field in &fields_named.named {
|
||||||
@ -88,16 +75,15 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse response & error types
|
// response type (default = serde_json::Value)
|
||||||
let response_ty: Type = syn::parse_str(&response_s)
|
let response_ty: Type = syn::parse_str(&response_s)
|
||||||
.unwrap_or_else(|_| syn::parse_str("serde_json::Value").unwrap());
|
.unwrap_or_else(|_| syn::parse_str("serde_json::Value").unwrap());
|
||||||
let error_ty: Type = syn::parse_str(&error_s)
|
|
||||||
.unwrap_or_else(|_| syn::parse_str("Box<dyn std::error::Error + Send + Sync>").unwrap());
|
|
||||||
|
|
||||||
let method_lit = syn::LitStr::new(&method_s, Span::call_site());
|
let method_lit = syn::LitStr::new(&method_s, Span::call_site());
|
||||||
let live_lit = syn::LitStr::new(&live_url_s, Span::call_site());
|
let live_lit = syn::LitStr::new(&live_url_s, Span::call_site());
|
||||||
let sandbox_lit = syn::LitStr::new(&sandbox_url_s, Span::call_site());
|
let sandbox_lit = syn::LitStr::new(&sandbox_url_s, Span::call_site());
|
||||||
|
|
||||||
|
// ✅ Final merged + polished implementation
|
||||||
let expanded = quote! {
|
let expanded = quote! {
|
||||||
#input
|
#input
|
||||||
|
|
||||||
@ -107,60 +93,74 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
fn sandbox_url() -> &'static str { #sandbox_lit }
|
fn sandbox_url() -> &'static str { #sandbox_lit }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait]
|
||||||
impl http_core::Queryable for #struct_ident {
|
impl http_core::Queryable for #struct_ident {
|
||||||
type R = #response_ty;
|
type Response = #response_ty;
|
||||||
type E = #error_ty;
|
|
||||||
|
|
||||||
async fn send(
|
async fn send(
|
||||||
&self,
|
&self,
|
||||||
|
client: std::sync::Arc<reqwest::Client>,
|
||||||
override_url: Option<&str>,
|
override_url: Option<&str>,
|
||||||
sandbox: bool,
|
sandbox: bool,
|
||||||
method_override: Option<&str>,
|
method_override: Option<&str>,
|
||||||
headers: Option<Vec<(&str, &str)>>,
|
headers: Option<Vec<(&str, &str)>>,
|
||||||
) -> Result<Self::R, Self::E> {
|
) -> anyhow::Result<Self::Response> {
|
||||||
use awc::Client;
|
use http_core::{HasHttp, replace_path_params, append_query_params};
|
||||||
use http_core::HasHttp;
|
use anyhow::{Context, anyhow};
|
||||||
use urlencoding::encode;
|
use serde_json::to_vec;
|
||||||
|
|
||||||
let mut query_params: Vec<(String, String)> = Vec::new();
|
let mut query_params: Vec<(String, String)> = Vec::new();
|
||||||
#(#qparam_snippets)*
|
#(#qparam_snippets)*
|
||||||
|
|
||||||
let mut url = if let Some(u) = override_url {
|
// base URL resolution
|
||||||
u.to_string()
|
let mut url = override_url.unwrap_or_else(|| {
|
||||||
} else if sandbox {
|
if sandbox {
|
||||||
<Self as HasHttp>::sandbox_url().to_string()
|
<Self as HasHttp>::sandbox_url()
|
||||||
} else {
|
} else {
|
||||||
<Self as HasHttp>::live_url().to_string()
|
<Self as HasHttp>::live_url()
|
||||||
};
|
}
|
||||||
|
}).to_string();
|
||||||
|
|
||||||
if !query_params.is_empty() {
|
// replace path params if any exist in the template
|
||||||
let qs = query_params.into_iter()
|
url = replace_path_params(&url, &[], self)
|
||||||
.map(|(k,v)| format!("{}={}", k, encode(&v)))
|
.context("Failed to replace path params")?;
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("&");
|
|
||||||
url.push('?');
|
|
||||||
url.push_str(&qs);
|
|
||||||
}
|
|
||||||
|
|
||||||
let method = method_override.unwrap_or(#method_lit);
|
// append query params if present
|
||||||
let client = Client::default();
|
url = append_query_params(&url, &query_params.iter().map(|(k, _)| k.clone()).collect::<Vec<_>>(), self)
|
||||||
let mut req = match method {
|
.context("Failed to append query params")?;
|
||||||
"GET" => client.get(url.clone()),
|
|
||||||
"POST" => client.post(url.clone()),
|
let method_str = method_override.unwrap_or(#method_lit);
|
||||||
"PUT" => client.put(url.clone()),
|
let method = reqwest::Method::from_bytes(method_str.as_bytes())
|
||||||
"DELETE" => client.delete(url.clone()),
|
.unwrap_or(reqwest::Method::GET);
|
||||||
"PATCH" => client.patch(url.clone()),
|
|
||||||
_ => client.get(url.clone()),
|
// build request
|
||||||
};
|
let mut req = client.request(method.clone(), &url);
|
||||||
|
|
||||||
if let Some(hs) = headers {
|
if let Some(hs) = headers {
|
||||||
for (k, v) in hs { req = req.append_header((k, v)); }
|
for (k, v) in hs {
|
||||||
|
req = req.header(k, v);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp = req.send().await.map_err(Into::into)?;
|
// send body if applicable
|
||||||
let body = resp.body().await.map_err(Into::into)?;
|
let resp = if matches!(method, reqwest::Method::POST | reqwest::Method::PUT | reqwest::Method::PATCH) {
|
||||||
let parsed: Self::R = serde_json::from_slice(&body).map_err(Into::into)?;
|
let body = to_vec(self).context("Failed to serialize request body")?;
|
||||||
|
req = req.body(body);
|
||||||
|
req.send().await.context("HTTP request failed")?
|
||||||
|
} else {
|
||||||
|
req.send().await.context("HTTP request failed")?
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle response
|
||||||
|
let status = resp.status();
|
||||||
|
let bytes = resp.bytes().await.context("Failed to read response body")?;
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(anyhow!("HTTP {}: {}", status, String::from_utf8_lossy(&bytes)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let parsed: Self::Response = serde_json::from_slice(&bytes)
|
||||||
|
.context("Failed to deserialize response")?;
|
||||||
Ok(parsed)
|
Ok(parsed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -169,86 +169,136 @@ pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
TokenStream::from(expanded)
|
TokenStream::from(expanded)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// CLI generator macro (keeps your existing CLI structure)
|
#[proc_macro]
|
||||||
#[proc_macro_attribute]
|
pub fn dispatch(input: TokenStream) -> TokenStream {
|
||||||
pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
use quote::format_ident;
|
||||||
let input_enum = parse_macro_input!(item as ItemEnum);
|
let input = parse_macro_input!(input as syn::ExprTuple);
|
||||||
let top_enum_ident = &input_enum.ident;
|
|
||||||
let top_variants = &input_enum.variants;
|
|
||||||
|
|
||||||
// Build outer match arms same as your previous implementation
|
if input.elems.len() != 2 {
|
||||||
let match_arms: Vec<_> = top_variants.iter().map(|variant| {
|
return syn::Error::new_spanned(input, "Expected (TopEnumIdent, hashmap!{ ... })")
|
||||||
let variant_ident = &variant.ident;
|
.to_compile_error()
|
||||||
|
.into();
|
||||||
|
}
|
||||||
|
|
||||||
// Expect tuple variant like Alpaca(AlpacaCmd)
|
// Extract top enum name
|
||||||
let inner_type = match &variant.fields {
|
let top_enum_ident = match &input.elems[0] {
|
||||||
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
syn::Expr::Path(p) => &p.path.segments.last().unwrap().ident,
|
||||||
match &fields.unnamed.first().unwrap().ty {
|
other => {
|
||||||
syn::Type::Path(p) => p.path.segments.last().unwrap().ident.clone(),
|
return syn::Error::new_spanned(other, "First element must be a type path").to_compile_error().into();
|
||||||
_ => panic!("Expected tuple variant with a type path"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => panic!("Each variant must be a tuple variant like `Alpaca(AlpacaCmd)`"),
|
|
||||||
};
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
#top_enum_ident::#variant_ident(inner) => {
|
|
||||||
match inner {
|
|
||||||
#inner_type::Bulk { input } => {
|
|
||||||
let mut reader: Box<dyn std::io::Read> = match input {
|
|
||||||
Some(path) => Box::new(std::fs::File::open(path)?),
|
|
||||||
None => Box::new(std::io::stdin()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut buf = String::new();
|
|
||||||
reader.read_to_string(&mut buf)?;
|
|
||||||
let queries: Vec<#inner_type> = serde_json::from_str(&buf)?;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
let client = Arc::new(awc::Client::default());
|
|
||||||
let keys = Arc::new(crate::load_api_keys()?);
|
|
||||||
|
|
||||||
// Spawn tasks for each query and await
|
|
||||||
let mut handles = Vec::new();
|
|
||||||
for q in queries {
|
|
||||||
let client = Arc::clone(&client);
|
|
||||||
let keys = Arc::clone(&keys);
|
|
||||||
handles.push(tokio::spawn(async move {
|
|
||||||
q.send_all(&client, &keys, None, false, None).await
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
for h in handles {
|
|
||||||
h.await??;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
other => {
|
|
||||||
let client = awc::Client::default();
|
|
||||||
let keys = crate::load_api_keys()?;
|
|
||||||
other.send_all(&client, &keys, None, false, None).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
let expanded = quote! {
|
|
||||||
use clap::Parser;
|
|
||||||
use std::io::Read;
|
|
||||||
|
|
||||||
#[derive(clap::Parser, Debug)]
|
|
||||||
#input_enum
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
let cmd = #top_enum_ident::parse();
|
|
||||||
match cmd {
|
|
||||||
#(#match_arms),*
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TokenStream::from(expanded)
|
// Parse second argument (the hashmap! macro)
|
||||||
|
let map_expr = &input.elems[1];
|
||||||
|
let map_macro = match map_expr {
|
||||||
|
syn::Expr::Macro(m) if m.mac.path.is_ident("hashmap") => m,
|
||||||
|
other => {
|
||||||
|
return syn::Error::new_spanned(other, "Second element must be a hashmap! macro").to_compile_error().into();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let tokens: proc_macro2::TokenStream = map_macro.mac.tokens.clone();
|
||||||
|
let parser = syn::parse2::<syn::ExprArray>(tokens);
|
||||||
|
|
||||||
|
let pairs: Vec<(syn::Ident, Vec<(syn::Ident, syn::Path)>)> = if let Ok(array) = parser {
|
||||||
|
array.elems.into_iter().filter_map(|expr| {
|
||||||
|
if let syn::Expr::Tuple(tuple) = expr {
|
||||||
|
if tuple.elems.len() == 2 {
|
||||||
|
if let (syn::Expr::Path(api_ident), syn::Expr::Array(cmds)) =
|
||||||
|
(&tuple.elems[0], &tuple.elems[1])
|
||||||
|
{
|
||||||
|
let api_ident = api_ident.path.segments.last().unwrap().ident.clone();
|
||||||
|
let mut subcmds = Vec::new();
|
||||||
|
for cmd_expr in &cmds.elems {
|
||||||
|
if let syn::Expr::Tuple(cmd_pair) = cmd_expr {
|
||||||
|
if cmd_pair.elems.len() == 2 {
|
||||||
|
if let (syn::Expr::Path(cmd_ident), syn::Expr::Path(path)) =
|
||||||
|
(&cmd_pair.elems[0], &cmd_pair.elems[1])
|
||||||
|
{
|
||||||
|
subcmds.push((
|
||||||
|
cmd_ident.path.segments.last().unwrap().ident.clone(),
|
||||||
|
path.path.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Some((api_ident, subcmds));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}).collect()
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate async match arms for each API
|
||||||
|
let mut match_arms = Vec::new();
|
||||||
|
for (api_ident, subcmds) in pairs {
|
||||||
|
for (cmd_ident, path) in subcmds {
|
||||||
|
match_arms.push(quote! {
|
||||||
|
#top_enum_ident::#api_ident(inner) => {
|
||||||
|
match inner {
|
||||||
|
#path::#cmd_ident(reqs) => {
|
||||||
|
use futures::stream::{FuturesUnordered, StreamExt};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
|
// Limit concurrency to 10 by default
|
||||||
|
let semaphore = Arc::new(Semaphore::new(10));
|
||||||
|
|
||||||
|
let mut tasks = FuturesUnordered::new();
|
||||||
|
for req in reqs {
|
||||||
|
let client = client.clone();
|
||||||
|
let permit = semaphore.clone().acquire_owned().await.unwrap();
|
||||||
|
let override_url = override_url.cloned();
|
||||||
|
let method_override = method_override.cloned();
|
||||||
|
let keys = keys.clone();
|
||||||
|
let sandbox = sandbox;
|
||||||
|
|
||||||
|
tasks.push(tokio::spawn(async move {
|
||||||
|
let _permit = permit;
|
||||||
|
req.send(client, override_url.as_deref(), sandbox, method_override.as_deref(), None)
|
||||||
|
.await
|
||||||
|
.map(|r| serde_json::to_value(r).unwrap_or(serde_json::Value::Null))
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
while let Some(res) = tasks.next().await {
|
||||||
|
results.push(res.unwrap_or_else(|e| Err(anyhow::anyhow!("JoinError: {}", e))));
|
||||||
|
}
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let expanded = quote! {
|
||||||
|
#[async_trait]
|
||||||
|
impl http_core::ApiDispatch for #top_enum_ident {
|
||||||
|
type Request = ();
|
||||||
|
type Response = ();
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
async fn send_all(
|
||||||
|
&self,
|
||||||
|
client: std::sync::Arc<reqwest::Client>,
|
||||||
|
keys: std::sync::Arc<http_core::Keys>,
|
||||||
|
override_url: Option<&str>,
|
||||||
|
sandbox: bool,
|
||||||
|
method_override: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<anyhow::Result<serde_json::Value>>> {
|
||||||
|
match self {
|
||||||
|
#(#match_arms),*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
expanded.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user