diff --git a/graphql_client_cli/Cargo.toml b/graphql_client_cli/Cargo.toml index ba2c2dbb..d583d654 100644 --- a/graphql_client_cli/Cargo.toml +++ b/graphql_client_cli/Cargo.toml @@ -20,7 +20,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = "^1.0" log = "^0.4" env_logger = "^0.6" -syn = "1.0" +syn = { version = "^2.0", features = ["full"] } [features] default = [] diff --git a/graphql_client_cli/src/generate.rs b/graphql_client_cli/src/generate.rs index 4a033829..2871887e 100644 --- a/graphql_client_cli/src/generate.rs +++ b/graphql_client_cli/src/generate.rs @@ -8,7 +8,7 @@ use std::fs::File; use std::io::Write as _; use std::path::PathBuf; use std::process::Stdio; -use syn::Token; +use syn::{token::Paren, token::Pub, VisRestricted, Visibility}; pub(crate) struct CliCodegenParams { pub query_path: PathBuf, @@ -22,6 +22,7 @@ pub(crate) struct CliCodegenParams { pub output_directory: Option, pub custom_scalars_module: Option, pub fragments_other_variant: bool, + pub external_enums: Option>, } const WARNING_SUPPRESSION: &str = "#![allow(clippy::all, warnings)]"; @@ -39,18 +40,26 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { selected_operation, custom_scalars_module, fragments_other_variant, + external_enums, } = params; let deprecation_strategy = deprecation_strategy.as_ref().and_then(|s| s.parse().ok()); let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Cli); - options.set_module_visibility( - syn::VisPublic { - pub_token: ::default(), - } - .into(), - ); + options.set_module_visibility(match _module_visibility { + Some(v) => match v.to_lowercase().as_str() { + "pub" => Visibility::Public(Pub::default()), + "inherited" => Visibility::Inherited, + _ => Visibility::Restricted(VisRestricted { + pub_token: Pub::default(), + in_token: None, + paren_token: Paren::default(), + path: syn::parse_str(&v).unwrap(), + }), + }, + None => Visibility::Public(Pub::default()), + }); options.set_fragments_other_variant(fragments_other_variant); @@ -70,6 +79,10 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { options.set_deprecation_strategy(deprecation_strategy); } + if let Some(external_enums) = external_enums { + options.set_extern_enums(external_enums); + } + if let Some(custom_scalars_module) = custom_scalars_module { let custom_scalars_module = syn::parse_str(&custom_scalars_module) .map_err(|_| Error::message("Invalid custom scalar module path".to_owned()))?; diff --git a/graphql_client_cli/src/main.rs b/graphql_client_cli/src/main.rs index 15caa374..94a3fc0d 100644 --- a/graphql_client_cli/src/main.rs +++ b/graphql_client_cli/src/main.rs @@ -89,6 +89,9 @@ enum Cli { /// --fragments-other-variant #[clap(long = "fragments-other-variant")] fragments_other_variant: bool, + /// List of externally defined enum types. Type names must match those used in the schema exactly + #[clap(long = "external-enums", num_args(0..), action(clap::ArgAction::Append))] + external_enums: Option>, }, } @@ -126,6 +129,7 @@ fn main() -> CliResult<()> { selected_operation, custom_scalars_module, fragments_other_variant, + external_enums, } => generate::generate_code(generate::CliCodegenParams { query_path, schema_path, @@ -138,6 +142,7 @@ fn main() -> CliResult<()> { output_directory, custom_scalars_module, fragments_other_variant, + external_enums, }), } } diff --git a/graphql_client_codegen/Cargo.toml b/graphql_client_codegen/Cargo.toml index 5a20e1f7..8f30cbb4 100644 --- a/graphql_client_codegen/Cargo.toml +++ b/graphql_client_codegen/Cargo.toml @@ -16,4 +16,4 @@ proc-macro2 = { version = "^1.0", features = [] } quote = "^1.0" serde_json = "1.0" serde = { version = "^1.0", features = ["derive"] } -syn = "^1.0" +syn = { version = "^2.0", features = [ "full" ] } diff --git a/graphql_query_derive/Cargo.toml b/graphql_query_derive/Cargo.toml index 11f6e6fd..b48f5845 100644 --- a/graphql_query_derive/Cargo.toml +++ b/graphql_query_derive/Cargo.toml @@ -11,6 +11,6 @@ edition = "2018" proc-macro = true [dependencies] -syn = { version = "^1.0", features = ["extra-traits"] } +syn = { version = "^2.0", features = ["extra-traits"] } proc-macro2 = { version = "^1.0", features = [] } graphql_client_codegen = { path = "../graphql_client_codegen/", version = "0.14.0" } diff --git a/graphql_query_derive/src/attributes.rs b/graphql_query_derive/src/attributes.rs index e9d31e14..99158217 100644 --- a/graphql_query_derive/src/attributes.rs +++ b/graphql_query_derive/src/attributes.rs @@ -1,4 +1,6 @@ +use proc_macro2::TokenTree; use std::str::FromStr; +use syn::Meta; use graphql_client_codegen::deprecation::DeprecationStrategy; use graphql_client_codegen::normalization::Normalization; @@ -6,26 +8,18 @@ use graphql_client_codegen::normalization::Normalization; const DEPRECATION_ERROR: &str = "deprecated must be one of 'allow', 'deny', or 'warn'"; const NORMALIZATION_ERROR: &str = "normalization must be one of 'none' or 'rust'"; -/// The `graphql` attribute as a `syn::Path`. -fn path_to_match() -> syn::Path { - syn::parse_str("graphql").expect("`graphql` is a valid path") -} - pub fn ident_exists(ast: &syn::DeriveInput, ident: &str) -> Result<(), syn::Error> { - let graphql_path = path_to_match(); let attribute = ast .attrs .iter() - .find(|attr| attr.path == graphql_path) + .find(|attr| attr.path().is_ident("graphql")) .ok_or_else(|| syn::Error::new_spanned(ast, "The graphql attribute is missing"))?; - if let syn::Meta::List(items) = &attribute.parse_meta().expect("Attribute is well formatted") { - for item in items.nested.iter() { - if let syn::NestedMeta::Meta(syn::Meta::Path(path)) = item { - if let Some(ident_) = path.get_ident() { - if ident_ == ident { - return Ok(()); - } + if let Meta::List(list) = &attribute.meta { + for item in list.tokens.clone().into_iter() { + if let TokenTree::Ident(ident_) = item { + if ident_ == ident { + return Ok(()); } } } @@ -39,21 +33,21 @@ pub fn ident_exists(ast: &syn::DeriveInput, ident: &str) -> Result<(), syn::Erro /// Extract an configuration parameter specified in the `graphql` attribute. pub fn extract_attr(ast: &syn::DeriveInput, attr: &str) -> Result { - let attributes = &ast.attrs; - let graphql_path = path_to_match(); - let attribute = attributes + let attribute = ast + .attrs .iter() - .find(|attr| attr.path == graphql_path) + .find(|a| a.path().is_ident("graphql")) .ok_or_else(|| syn::Error::new_spanned(ast, "The graphql attribute is missing"))?; - if let syn::Meta::List(items) = &attribute.parse_meta().expect("Attribute is well formatted") { - for item in items.nested.iter() { - if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = item { - let syn::MetaNameValue { path, lit, .. } = name_value; - if let Some(ident) = path.get_ident() { - if ident == attr { - if let syn::Lit::Str(lit) = lit { - return Ok(lit.value()); - } + + if let Meta::List(list) = &attribute.meta { + let mut iter = list.tokens.clone().into_iter(); + while let Some(item) = iter.next() { + if let TokenTree::Ident(ident) = item { + if ident == attr { + iter.next(); + if let Some(TokenTree::Literal(lit)) = iter.next() { + let lit_str: syn::LitStr = syn::parse_str(&lit.to_string())?; + return Ok(lit_str.value()); } } } @@ -68,38 +62,41 @@ pub fn extract_attr(ast: &syn::DeriveInput, attr: &str) -> Result Result, syn::Error> { - let attributes = &ast.attrs; - let graphql_path = path_to_match(); - let attribute = attributes + let attribute = ast + .attrs .iter() - .find(|attr| attr.path == graphql_path) + .find(|a| a.path().is_ident("graphql")) .ok_or_else(|| syn::Error::new_spanned(ast, "The graphql attribute is missing"))?; - if let syn::Meta::List(items) = &attribute.parse_meta().expect("Attribute is well formatted") { - for item in items.nested.iter() { - if let syn::NestedMeta::Meta(syn::Meta::List(value_list)) = item { - if let Some(ident) = value_list.path.get_ident() { - if ident == attr { - return value_list - .nested - .iter() - .map(|lit| { - if let syn::NestedMeta::Lit(syn::Lit::Str(lit)) = lit { - Ok(lit.value()) - } else { - Err(syn::Error::new_spanned( - lit, - "Attribute inside value list must be a literal", - )) - } - }) - .collect(); + + let mut result = Vec::new(); + + if let Meta::List(list) = &attribute.meta { + let mut iter = list.tokens.clone().into_iter(); + while let Some(item) = iter.next() { + if let TokenTree::Ident(ident) = item { + if ident == attr { + if let Some(TokenTree::Group(group)) = iter.next() { + for token in group.stream() { + if let TokenTree::Literal(lit) = token { + let lit_str: syn::LitStr = syn::parse_str(&lit.to_string())?; + result.push(lit_str.value()); + } + } + return Ok(result); } } } } } - Err(syn::Error::new_spanned(ast, "Attribute not found")) + if result.is_empty() { + Err(syn::Error::new_spanned( + ast, + format!("Attribute list `{}` not found or empty", attr), + )) + } else { + Ok(result) + } } /// Get the deprecation from a struct attribute in the derive case. @@ -278,4 +275,24 @@ mod test { let parsed = syn::parse_str(input).unwrap(); assert!(!extract_skip_serializing_none(&parsed)); } + + #[test] + fn test_external_enums() { + let input = r#" + #[derive(Serialize, Deserialize, Debug)] + #[derive(GraphQLQuery)] + #[graphql( + schema_path = "x", + query_path = "x", + extern_enums("Direction", "DistanceUnit"), + )] + struct MyQuery; + "#; + let parsed: syn::DeriveInput = syn::parse_str(input).unwrap(); + + assert_eq!( + extract_attr_list(&parsed, "extern_enums").ok().unwrap(), + vec!["Direction", "DistanceUnit"], + ); + } }