From 0ba23fad0ee8ee05ee56b1a536a41f4578d78a34 Mon Sep 17 00:00:00 2001 From: buckn Date: Thu, 17 Jul 2025 12:18:58 -0400 Subject: [PATCH] ud --- src/lib.rs | 203 +++++++++++++++++++++++++++-------------------------- 1 file changed, 104 insertions(+), 99 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 763e512..a456c44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,120 +147,125 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream { let input_enum = parse_macro_input!(item as ItemEnum); - let enum_name = &input_enum.ident; - let variants = &input_enum.variants; + let top_enum_ident = &input_enum.ident; // e.g., "Cmd" + let top_variants = &input_enum.variants; - // Match arms for regular command variants (Single, Asset, etc.) - let regular_arms = variants.iter().filter_map(|v| { - let v_name = &v.ident; + // For each variant like Alpaca(AlpacaCmd) + let outer_match_arms = top_variants.iter().map(|v| { + let variant_ident = &v.ident; // e.g., "Alpaca" - // Skip Bulk variant — we handle it separately - if v_name == "Bulk" { - return None; + // Extract the inner sub-enum type like AlpacaCmd + if let syn::Fields::Unnamed(fields) = &v.fields { + if let Some(field) = fields.unnamed.first() { + if let syn::Type::Path(inner_type) = &field.ty { + let inner_type_ident = &inner_type.path.segments.last().unwrap().ident; + + // Match arms inside the nested enum (AlpacaCmd) + let inner_match_arm = quote! { + match #inner_type_ident::parse() { + #inner_type_ident::Bulk { input } => { + // Bulk: read and parse Vec + let mut reader: Box = match input { + Some(path) => Box::new(std::fs::File::open(path)?), + None => Box::new(std::io::stdin()), + }; + let mut buf = String::new(); + reader.read_to_string(&mut buf)?; + let queries: Vec<#inner_type_ident> = serde_json::from_str(&buf)?; + + use std::sync::Arc; + let client = Arc::new(awc::Client::default()); + let api_keys = Arc::new(crate::load_api_keys()?); + + const THREADS: usize = 4; + let total = queries.len(); + let per_thread = total / THREADS; + let shared_queries = Arc::new(queries); + + let mut handles = Vec::new(); + for i in 0..THREADS { + let queries = Arc::clone(&shared_queries); + let client = Arc::clone(&client); + let keys = Arc::clone(&api_keys); + let start = i * per_thread; + let end = if i == THREADS - 1 { total } else { start + per_thread }; + + let handle = std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + for q in &queries[start..end] { + rt.block_on(q.send_all(&client, &keys)); + } + }); + + handles.push(handle); + } + + for h in handles { + h.join().expect("Thread panicked"); + } + } + other => { + let client = awc::Client::default(); + let keys = crate::load_api_keys()?; + other.send_all(&client, &keys).await?; + } + } + }; + + // Wrap the outer enum match (Cmd::Alpaca(inner)) + return quote! { + #top_enum_ident::#variant_ident(inner) => { + #inner_match_arm + } + }; + } + } } - Some(quote! { - #enum_name::#v_name(req) => { - let res = req.send(client.clone(), &api_key).await?; - let body = res.body().await?; - println!("{}", std::str::from_utf8(&body)?); - } - }) + panic!("Each outer enum variant must be a tuple variant like `Alpaca(AlpacaCmd)`"); }); + // Generate the final program let expanded = quote! { - #[derive(structopt::StructOpt, Debug)] + use clap::Parser; + use std::io::Read; + + #[derive(clap::Parser, Debug)] #input_enum #[tokio::main] async fn main() -> Result<(), Box> { - use structopt::StructOpt; - use std::fs::File; - use std::io::{BufReader, Read}; - use std::sync::Arc; - use std::thread; - - const THREADS: usize = 4; - - // Initialize shared HTTP client and API key - let client = Arc::new(awc::Client::default()); - let api_key = std::env::var("APCA_API_KEY_ID")?; - let cmd = #enum_name::from_args(); - + let cmd = #top_enum_ident::parse(); match cmd { - #(#regular_arms)* - - #enum_name::Bulk { input } => { - // Choose input source: file or stdin - let mut reader: Box = match input { - Some(path) => Box::new(File::open(path)?), - None => Box::new(std::io::stdin()), - }; - - // Read input JSON into buffer - let mut buf = String::new(); - reader.read_to_string(&mut buf)?; - - // Deserialize into Vec - let queries: Vec = serde_json::from_str(&buf)?; - let total = queries.len(); - - if total == 0 { - eprintln!("No queries provided."); - return Ok(()); - } - - let shared_queries = Arc::new(queries); - let shared_key = Arc::new(api_key); - - let per_thread = total / THREADS; - - let mut handles = Vec::with_capacity(THREADS); - for i in 0..THREADS { - let queries_clone = Arc::clone(&shared_queries); - let client_clone = Arc::clone(&client); - let key_clone = Arc::clone(&shared_key); - - let start_index = i * per_thread; - let end_index = if i == THREADS - 1 { - total // Last thread gets the remainder - } else { - start_index + per_thread - }; - - let handle = thread::spawn(move || { - let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime"); - for idx in start_index..end_index { - let query = &queries_clone[idx]; - let send_result = rt.block_on(query.send(client_clone.clone(), &key_clone)); - - match send_result { - Ok(response) => { - let body_result = rt.block_on(response.body()); - match body_result { - Ok(body) => println!("{}", String::from_utf8_lossy(&body)), - Err(e) => eprintln!("Error reading response body: {:?}", e), - } - } - Err(e) => { - eprintln!("Request failed: {:?}", e); - } - } - } - }); - - handles.push(handle); - } - - // Wait for all threads to complete - for handle in handles { - handle.join().expect("A thread panicked"); - } - } + #(#outer_match_arms),* } - Ok(()) } + + // Helper trait to unify async calls on sub-commands + trait ApiDispatch { + fn send_all( + &self, + client: &awc::Client, + keys: &std::collections::HashMap, + ) -> std::pin::Pin>> + Send>>; + } + + // Implement ApiDispatch for every subcommand variant + #(impl ApiDispatch for #top_enum_ident { + fn send_all( + &self, + client: &awc::Client, + keys: &std::collections::HashMap, + ) -> std::pin::Pin>> + Send>> { + Box::pin(async move { + match self { + #(#outer_match_arms),* + } + Ok(()) + }) + } + })* }; TokenStream::from(expanded)