// http_derive/src/lib.rs — Rust, 287 lines, 12 KiB
// (file-browser header pasted into the source; kept as a comment so the file compiles)

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{quote};
use syn::{
parse_macro_input, ItemStruct, ItemEnum, Fields, Type, Meta, Lit, Expr,
punctuated::Punctuated, token::Comma, MetaNameValue, parse::Parser
};
#[proc_macro_attribute]
pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
// parse the struct we're attached to
let mut input = parse_macro_input!(item as ItemStruct);
let struct_ident = &input.ident;
// defaults
let mut method_s = "GET".to_string();
let mut url_s = "".to_string();
let mut response_s = format!("{}Resp", struct_ident);
let mut error_s = "Box<dyn std::error::Error>".to_string();
// Convert attr TokenStream -> proc_macro2 TokenStream so we can inspect/try parses safely
let attr_ts: proc_macro2::TokenStream = proc_macro2::TokenStream::from(attr);
if !attr_ts.is_empty() {
// First try: parse as syn::Meta (preferred)
match syn::parse2::<Meta>(attr_ts.clone()) {
Ok(meta) => match meta {
Meta::List(meta_list) => {
// parse the inner tokens into name = value pairs
let nested: Punctuated<MetaNameValue, Comma> =
Punctuated::parse_terminated.parse2(meta_list.tokens)
.expect("failed to parse http attribute list");
for nv in nested {
if let Some(ident) = nv.path.get_ident() {
let key = ident.to_string();
// nv.value is an Expr in syn 2.x; expect Expr::Lit(Lit::Str)
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(litstr) = expr_lit.lit {
let val = litstr.value();
match key.as_str() {
"method" => method_s = val,
"url" => url_s = val,
"response" => response_s = val,
"error" => if !val.is_empty() { error_s = val },
_ => {}
}
}
}
}
}
}
Meta::NameValue(nv) => {
// handle weird case like `#[http = "foo"]` (unlikely) — accept name-value if it has ident
if let Some(ident) = nv.path.get_ident() {
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(litstr) = expr_lit.lit {
let key = ident.to_string();
let val = litstr.value();
match key.as_str() {
"method" => method_s = val,
"url" => url_s = val,
"response" => response_s = val,
"error" => if !val.is_empty() { error_s = val },
_ => {}
}
}
}
}
}
Meta::Path(_) => {
// attribute present but without key/value — keep defaults
}
},
Err(_) => {
// Fallback: maybe the tokens are just a comma-separated `k = "v", ...` list without meta wrapper.
if let Ok(nested) = Punctuated::<MetaNameValue, Comma>::parse_terminated.parse2(attr_ts.clone()) {
for nv in nested {
if let Some(ident) = nv.path.get_ident() {
let key = ident.to_string();
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(litstr) = expr_lit.lit {
let val = litstr.value();
match key.as_str() {
"method" => method_s = val,
"url" => url_s = val,
"response" => response_s = val,
"error" => if !val.is_empty() { error_s = val },
_ => {}
}
}
}
}
}
}
}
}
}
// Attach #[http_error_type = "..."] for build.rs introspection
let error_lit = syn::LitStr::new(&error_s, Span::call_site());
input.attrs.push(syn::parse_quote!(#[http_error_type = #error_lit]));
// Re-attach compact http attr (so your build.rs logic still sees it)
let method_lit = syn::LitStr::new(&method_s, Span::call_site());
let url_lit = syn::LitStr::new(&url_s, Span::call_site());
let resp_lit = syn::LitStr::new(&response_s, Span::call_site());
input.attrs.push(syn::parse_quote!(#[http(method = #method_lit, url = #url_lit, response = #resp_lit)]));
// Build query param snippets for lnk_p_* fields
let mut qparam_snippets: Vec<proc_macro2::TokenStream> = Vec::new();
if let Fields::Named(fields_named) = &input.fields {
for field in &fields_named.named {
if let Some(ident) = &field.ident {
if let Some(key) = ident.to_string().strip_prefix("lnk_p_") {
let key_lit = syn::LitStr::new(key, Span::call_site());
qparam_snippets.push(quote! {
if let Some(val) = &self.#ident {
query_params.push((#key_lit.to_string(), val.to_string()));
}
});
}
}
}
}
// Parse response & error into syn::Type so complex paths (crate::X) are allowed
let response_ty: Type = syn::parse_str(&response_s).unwrap_or_else(|_| {
syn::parse_str::<Type>("serde_json::Value").expect("fallback parse")
});
let error_ty: Type = syn::parse_str(&error_s).unwrap_or_else(|_| {
syn::parse_str::<Type>("Box<dyn std::error::Error>").expect("fallback parse")
});
// Build the impl
let expanded = quote! {
#input
#[async_trait::async_trait]
impl http_core::Queryable for #struct_ident {
type R = #response_ty;
type E = #error_ty;
async fn send(
&self,
override_url: Option<&str>,
sandbox: bool,
method_override: Option<&str>,
headers: Option<Vec<(&str, &str)>>,
) -> Result<Self::R, Self::E> {
use awc::Client;
use urlencoding::encode;
use http_core::HasHttp;
let mut query_params: Vec<(String,String)> = Vec::new();
// expand lnk_p_* fields
#(#qparam_snippets)*
let mut url = if let Some(u) = override_url {
u.to_string()
} else if sandbox {
<Self as HasHttp>::sandbox_url().to_string()
} else {
<Self as HasHttp>::live_url().to_string()
};
if !query_params.is_empty() {
let qs = query_params.into_iter()
.map(|(k,v)| format!("{}={}", k, encode(&v)))
.collect::<Vec<_>>()
.join("&");
url.push('?');
url.push_str(&qs);
}
let method = method_override.unwrap_or(#method_lit);
let client = Client::default();
let mut request = match method {
"GET" => client.get(url.clone()),
"POST" => client.post(url.clone()),
"PUT" => client.put(url.clone()),
"DELETE" => client.delete(url.clone()),
"PATCH" => client.patch(url.clone()),
_ => client.get(url.clone()),
};
if let Some(hdrs) = headers {
for (k,v) in hdrs {
request = request.append_header((k, v));
}
}
let response = request.send().await.map_err(Into::into)?;
let bytes = response.body().await.map_err(Into::into)?;
let parsed: Self::R = serde_json::from_slice(&bytes).map_err(Into::into)?;
Ok(parsed)
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input_enum = parse_macro_input!(item as ItemEnum);
let top_enum_ident = &input_enum.ident;
let top_variants = &input_enum.variants;
// Build outer match arms
let match_arms: Vec<_> = top_variants.iter().map(|variant| {
let variant_ident = &variant.ident;
// Expecting tuple variants like Alpaca(AlpacaCmd)
let inner_type = match &variant.fields {
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
match &fields.unnamed.first().unwrap().ty {
syn::Type::Path(p) => p.path.segments.last().unwrap().ident.clone(),
_ => panic!("Expected tuple variant with a type path"),
}
}
_ => panic!("Each variant must be a tuple variant like `Alpaca(AlpacaCmd)`"),
};
quote! {
#top_enum_ident::#variant_ident(inner) => {
match inner {
#inner_type::Bulk { input } => {
let mut reader: Box<dyn std::io::Read> = match input {
Some(path) => Box::new(std::fs::File::open(path)?),
None => Box::new(std::io::stdin()),
};
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
let queries: Vec<#inner_type> = serde_json::from_str(&buf)?;
use std::sync::Arc;
let client = Arc::new(awc::Client::default());
let keys = Arc::new(crate::load_api_keys()?);
// Spawn all queries as async tasks
let mut handles = Vec::new();
for q in queries {
let client = Arc::clone(&client);
let keys = Arc::clone(&keys);
handles.push(tokio::spawn(async move {
q.send_all(&client, &keys).await
}));
}
// Await all results and propagate first error (if any)
for h in handles {
h.await??;
}
}
other => {
let client = awc::Client::default();
let keys = crate::load_api_keys()?;
other.send_all(&client, &keys).await?;
}
}
}
}
}).collect();
// Generate the final code
let expanded = quote! {
use clap::Parser;
use std::io::Read;
#[derive(clap::Parser, Debug)]
#input_enum
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cmd = #top_enum_ident::parse();
match cmd {
#(#match_arms),*
}
Ok(())
}
};
TokenStream::from(expanded)
}