From ac0ace0b948a50c4f854440ccc71541de6e00592 Mon Sep 17 00:00:00 2001 From: buckn Date: Fri, 11 Jul 2025 11:58:29 -0400 Subject: [PATCH] added macro that auto impls the structopt query sending --- Cargo.toml | 2 +- src/lib.rs | 124 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ef580d5..5675a3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,4 +9,4 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.95" quote = "1.0.40" -syn = "2.0.104" +syn = { version = "2.0.104", features = ["full"] } diff --git a/src/lib.rs b/src/lib.rs index e6e67c3..8c29974 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, DeriveInput, Data, Fields, Lit}; +use syn::{parse_macro_input, ItemEnum, Lit, DeriveInput, Fields, Data}; #[proc_macro_derive(HttpRequest, attributes(http_get))] pub fn derive_http_get_request(input: TokenStream) -> TokenStream { @@ -191,3 +191,125 @@ pub fn derive_response_vec(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } + +#[proc_macro_attribute] +pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream { + let input_enum = parse_macro_input!(item as ItemEnum); + let enum_name = &input_enum.ident; + let variants = &input_enum.variants; + + // Match arms for regular command variants (Single, Asset, etc.) + let regular_arms = variants.iter().filter_map(|v| { + let v_name = &v.ident; + + // Skip Bulk variant — we handle it separately + if v_name == "Bulk" { + return None; + } + + Some(quote! { + #enum_name::#v_name(req) => { + let res = req.send(client.clone(), &api_key).await?; + let body = res.body().await?; + println!("{}", std::str::from_utf8(&body)?); + } + }) + }); + + let expanded = quote! { + #[derive(structopt::StructOpt, Debug)] + #input_enum + + #[tokio::main] + async fn main() -> Result<(), Box> { + use structopt::StructOpt; + use std::fs::File; + use std::io::{BufReader, Read}; + use std::sync::Arc; + use std::thread; + + const THREADS: usize = 4; + + // Initialize shared HTTP client and API key + let client = Arc::new(awc::Client::default()); + let api_key = std::env::var("APCA_API_KEY_ID")?; + let cmd = #enum_name::from_args(); + + match cmd { + #(#regular_arms)* + + #enum_name::Bulk { input } => { + // Choose input source: file or stdin + let mut reader: Box = match input { + Some(path) => Box::new(File::open(path)?), + None => Box::new(std::io::stdin()), + }; + + // Read input JSON into buffer + let mut buf = String::new(); + reader.read_to_string(&mut buf)?; + + // Deserialize into Vec + let queries: Vec = serde_json::from_str(&buf)?; + let total = queries.len(); + + if total == 0 { + eprintln!("No queries provided."); + return Ok(()); + } + + let shared_queries = Arc::new(queries); + let shared_key = Arc::new(api_key); + + let per_thread = total / THREADS; + + let mut handles = Vec::with_capacity(THREADS); + for i in 0..THREADS { + let queries_clone = Arc::clone(&shared_queries); + let client_clone = Arc::clone(&client); + let key_clone = Arc::clone(&shared_key); + + let start_index = i * per_thread; + let end_index = if i == THREADS - 1 { + total // Last thread gets the remainder + } else { + start_index + per_thread + }; + + let handle = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime"); + for idx in start_index..end_index { + let query = &queries_clone[idx]; + let send_result = rt.block_on(query.send(client_clone.clone(), &key_clone)); + + match send_result { + Ok(response) => { + let body_result = rt.block_on(response.body()); + match body_result { + Ok(body) => println!("{}", String::from_utf8_lossy(&body)), + Err(e) => eprintln!("Error reading response body: {:?}", e), + } + } + Err(e) => { + eprintln!("Request failed: {:?}", e); + } + } + } + }); + + handles.push(handle); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().expect("A thread panicked"); + } + } + } + + Ok(()) + } + }; + + TokenStream::from(expanded) +}