diff --git a/rust/extractor/macros/src/lib.rs b/rust/extractor/macros/src/lib.rs index c70856aad9f3..b79f0cc29391 100644 --- a/rust/extractor/macros/src/lib.rs +++ b/rust/extractor/macros/src/lib.rs @@ -19,11 +19,17 @@ pub fn extractor_cli_config(_attr: TokenStream, item: TokenStream) -> TokenStrea .fields .iter() .map(|f| { + let ty_tip = get_type_tip(&f.ty); if f.ident.as_ref().is_some_and(|i| i != "inputs") - && get_type_tip(&f.ty).is_some_and(|i| i == "Vec") + && ty_tip.is_some_and(|i| i == "Vec") { quote! { - #[serde(deserialize_with="deserialize_newline_or_comma_separated")] + #[serde(deserialize_with="deserialize::deserialize_newline_or_comma_separated_vec")] + #f + } + } else if ty_tip.is_some_and(|i| i == "FxHashMap" || i == "HashMap") { + quote! { + #[serde(deserialize_with="deserialize::deserialize_newline_or_comma_separated_map")] #f } } else { @@ -60,7 +66,7 @@ pub fn extractor_cli_config(_attr: TokenStream, item: TokenStream) -> TokenStrea quote! { #f } - } else if type_tip.is_some_and(|i| i == "Vec") { + } else if type_tip.is_some_and(|i| i == "Vec" || i == "FxHashMap" || i == "HashMap") { quote! { #[arg(long)] #id: Option diff --git a/rust/extractor/src/config.rs b/rust/extractor/src/config.rs index 82568b64553f..c87af7e77280 100644 --- a/rust/extractor/src/config.rs +++ b/rust/extractor/src/config.rs @@ -1,9 +1,8 @@ -mod deserialize_vec; +mod deserialize; use anyhow::Context; use clap::Parser; use codeql_extractor::trap; -use deserialize_vec::deserialize_newline_or_comma_separated; use figment::{ providers::{Env, Format, Serialized, Yaml}, value::Value, @@ -13,14 +12,15 @@ use itertools::Itertools; use ra_ap_cfg::{CfgAtom, CfgDiff}; use ra_ap_ide_db::FxHashMap; use ra_ap_intern::Symbol; -use ra_ap_paths::{AbsPath, Utf8PathBuf}; +use ra_ap_load_cargo::{LoadCargoConfig, ProcMacroServerChoice}; +use ra_ap_paths::{AbsPath, AbsPathBuf, Utf8PathBuf}; use ra_ap_project_model::{CargoConfig, CargoFeatures, CfgOverrides, RustLibSource, Sysroot}; use rust_extractor_macros::extractor_cli_config; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::fmt::Debug; use std::ops::Not; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; #[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize, Clone, Copy, clap::ValueEnum)] #[serde(rename_all = "lowercase")] @@ -50,6 +50,9 @@ pub struct Config { pub cargo_target: Option, pub cargo_features: Vec, pub cargo_cfg_overrides: Vec, + pub cargo_extra_env: FxHashMap, + pub cargo_extra_args: Vec, + pub cargo_all_targets: bool, pub logging_flamegraph: Option, pub logging_verbosity: Option, pub compression: Compression, @@ -57,6 +60,12 @@ pub struct Config { pub qltest: bool, pub qltest_cargo_check: bool, pub qltest_dependencies: Vec, + pub sysroot: Option, + pub sysroot_src: Option, + pub rustc_src: Option, + pub build_script_command: Vec, + pub extra_includes: Vec, + pub proc_macro_server: Option, } impl Config { @@ -92,44 +101,86 @@ impl Config { figment.extract().context("loading configuration") } - pub fn to_cargo_config(&self, dir: &AbsPath) -> CargoConfig { - let sysroot = Sysroot::discover(dir, &FxHashMap::default()); - let sysroot_src = sysroot.src_root().map(ToOwned::to_owned); - let sysroot = sysroot - .root() - .map(ToOwned::to_owned) - .map(RustLibSource::Path); - - let target_dir = self - .cargo_target_dir - .clone() - .unwrap_or_else(|| self.scratch_dir.join("target")); - let target_dir = Utf8PathBuf::from_path_buf(target_dir).ok(); - - let features = if self.cargo_features.is_empty() { - Default::default() - } else if self.cargo_features.contains(&"*".to_string()) { - CargoFeatures::All - } else { - CargoFeatures::Selected { - features: self.cargo_features.clone(), - no_default_features: false, + fn sysroot(&self, dir: &AbsPath) -> Sysroot { + let sysroot_input = self.sysroot.as_ref().map(|p| join_path_buf(dir, p)); + let sysroot_src_input = self.sysroot_src.as_ref().map(|p| join_path_buf(dir, p)); + match (sysroot_input, sysroot_src_input) { + (None, None) => Sysroot::discover(dir, &self.cargo_extra_env), + (Some(sysroot), None) => Sysroot::discover_sysroot_src_dir(sysroot), + (None, Some(sysroot_src)) => { + Sysroot::discover_with_src_override(dir, &self.cargo_extra_env, sysroot_src) } - }; + (Some(sysroot), Some(sysroot_src)) => Sysroot::new(Some(sysroot), Some(sysroot_src)), + } + } - let target = self.cargo_target.clone(); + fn proc_macro_server_choice(&self, dir: &AbsPath) -> ProcMacroServerChoice { + match &self.proc_macro_server { + Some(path) => match path.to_str() { + Some("none") => ProcMacroServerChoice::None, + Some("sysroot") => ProcMacroServerChoice::Sysroot, + _ => ProcMacroServerChoice::Explicit(join_path_buf(dir, path)), + }, + None => ProcMacroServerChoice::Sysroot, + } + } - let cfg_overrides = to_cfg_overrides(&self.cargo_cfg_overrides); + pub fn to_cargo_config(&self, dir: &AbsPath) -> (CargoConfig, LoadCargoConfig) { + let sysroot = self.sysroot(dir); + ( + CargoConfig { + all_targets: self.cargo_all_targets, + sysroot_src: sysroot.src_root().map(ToOwned::to_owned), + rustc_source: self + .rustc_src + .as_ref() + .map(|p| join_path_buf(dir, p)) + .or_else(|| sysroot.discover_rustc_src().map(AbsPathBuf::from)) + .map(RustLibSource::Path), + sysroot: sysroot + .root() + .map(ToOwned::to_owned) + .map(RustLibSource::Path), - CargoConfig { - sysroot, - sysroot_src, - target_dir, - features, - target, - cfg_overrides, - ..Default::default() - } + extra_env: self.cargo_extra_env.clone(), + extra_args: self.cargo_extra_args.clone(), + extra_includes: self + .extra_includes + .iter() + .map(|p| join_path_buf(dir, p)) + .collect(), + target_dir: Utf8PathBuf::from_path_buf( + self.cargo_target_dir + .clone() + .unwrap_or_else(|| self.scratch_dir.join("target")), + ) + .ok(), + features: if self.cargo_features.is_empty() { + Default::default() + } else if self.cargo_features.contains(&"*".to_string()) { + CargoFeatures::All + } else { + CargoFeatures::Selected { + features: self.cargo_features.clone(), + no_default_features: false, + } + }, + target: self.cargo_target.clone(), + cfg_overrides: to_cfg_overrides(&self.cargo_cfg_overrides), + wrap_rustc_in_build_scripts: false, + run_build_script_command: if self.build_script_command.is_empty() { + None + } else { + Some(self.build_script_command.clone()) + }, + ..Default::default() + }, + LoadCargoConfig { + load_out_dirs_from_check: true, + with_proc_macro_server: self.proc_macro_server_choice(dir), + prefill_caches: false, + }, + ) } } @@ -168,3 +219,10 @@ fn to_cfg_overrides(specs: &Vec) -> CfgOverrides { ..Default::default() } } + +fn join_path_buf(lhs: &AbsPath, rhs: &Path) -> AbsPathBuf { + let Ok(path) = Utf8PathBuf::from_path_buf(rhs.into()) else { + panic!("non utf8 input: {}", rhs.display()) + }; + lhs.join(path) +} diff --git a/rust/extractor/src/config/deserialize.rs b/rust/extractor/src/config/deserialize.rs new file mode 100644 index 000000000000..5953acd86057 --- /dev/null +++ b/rust/extractor/src/config/deserialize.rs @@ -0,0 +1,97 @@ +use serde::de::{Error, Unexpected, Visitor}; +use serde::Deserializer; +use std::collections::HashMap; +use std::fmt::Formatter; +use std::hash::BuildHasher; +use std::marker::PhantomData; + +// phantom data is required to allow parametrizing on `T` without actual `T` data +struct VectorVisitor>(PhantomData); +struct MapVisitor(PhantomData); + +impl<'de, T: From> Visitor<'de> for VectorVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("either a sequence, or a comma or newline separated string") + } + + fn visit_str(self, value: &str) -> Result, E> { + Ok(value + .split(['\n', ',']) + .map(|s| T::from(s.to_owned())) + .collect()) + } + + fn visit_seq(self, mut seq: A) -> Result, A::Error> + where + A: serde::de::SeqAccess<'de>, + { + let mut ret = Vec::new(); + while let Some(el) = seq.next_element::()? { + ret.push(T::from(el)); + } + Ok(ret) + } +} + +impl<'de, S: BuildHasher + Default> Visitor<'de> for MapVisitor { + type Value = HashMap; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str( + "either a sequence, or a comma or newline separated string of key=value entries", + ) + } + + fn visit_str(self, value: &str) -> Result { + value + .split(['\n', ',']) + .map(|s| { + s.split_once('=') + .ok_or_else(|| E::custom(format!("key=value expected, found {s}"))) + .map(|(key, value)| (key.to_owned(), value.to_owned())) + }) + .collect() + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut ret = HashMap::with_hasher(Default::default()); + while let Some(el) = seq.next_element::()? { + let (key, value) = el + .split_once('=') + .ok_or_else(|| A::Error::invalid_value(Unexpected::Str(&el), &self))?; + ret.insert(key.to_owned(), value.to_owned()); + } + Ok(ret) + } +} + +/// deserialize into a vector of `T` either of: +/// * a sequence of elements serializable into `String`s, or +/// * a single element serializable into `String`, then split on `,` and `\n` +pub(crate) fn deserialize_newline_or_comma_separated_vec< + 'a, + D: Deserializer<'a>, + T: From, +>( + deserializer: D, +) -> Result, D::Error> { + deserializer.deserialize_seq(VectorVisitor(PhantomData)) +} + +/// deserialize into a map of `String`s to `String`s either of: +/// * a sequence of elements serializable into `String`s, or +/// * a single element serializable into `String`, then split on `,` and `\n` +pub(crate) fn deserialize_newline_or_comma_separated_map< + 'a, + D: Deserializer<'a>, + S: BuildHasher + Default, +>( + deserializer: D, +) -> Result, D::Error> { + deserializer.deserialize_map(MapVisitor(PhantomData)) +} diff --git a/rust/extractor/src/config/deserialize_vec.rs b/rust/extractor/src/config/deserialize_vec.rs deleted file mode 100644 index 2d63046549b9..000000000000 --- a/rust/extractor/src/config/deserialize_vec.rs +++ /dev/null @@ -1,50 +0,0 @@ -use serde::de::Visitor; -use serde::Deserializer; -use std::fmt::Formatter; -use std::marker::PhantomData; - -// phantom data ise required to allow parametrizing on `T` without actual `T` data -struct VectorVisitor>(PhantomData); - -impl> VectorVisitor { - fn new() -> Self { - VectorVisitor(PhantomData) - } -} - -impl<'de, T: From> Visitor<'de> for VectorVisitor { - type Value = Vec; - - fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { - formatter.write_str("either a sequence, or a comma or newline separated string") - } - - fn visit_str(self, value: &str) -> Result, E> { - Ok(value - .split(['\n', ',']) - .map(|s| T::from(s.to_owned())) - .collect()) - } - - fn visit_seq(self, mut seq: A) -> Result, A::Error> - where - A: serde::de::SeqAccess<'de>, - { - let mut ret = Vec::new(); - while let Some(el) = seq.next_element::()? { - ret.push(T::from(el)); - } - Ok(ret) - } -} - -/// deserialize into a vector of `T` either of: -/// * a sequence of elements serializable into `String`s, or -/// * a single element serializable into `String`, then split on `,` and `\n` -/// -/// This is required to be in scope when the `extractor_cli_config` macro is used. -pub(crate) fn deserialize_newline_or_comma_separated<'a, D: Deserializer<'a>, T: From>( - deserializer: D, -) -> Result, D::Error> { - deserializer.deserialize_seq(VectorVisitor::new()) -} diff --git a/rust/extractor/src/main.rs b/rust/extractor/src/main.rs index 04aaf23c652e..48445a935c30 100644 --- a/rust/extractor/src/main.rs +++ b/rust/extractor/src/main.rs @@ -6,6 +6,7 @@ use archive::Archiver; use ra_ap_hir::Semantics; use ra_ap_ide_db::line_index::{LineCol, LineIndex}; use ra_ap_ide_db::RootDatabase; +use ra_ap_load_cargo::LoadCargoConfig; use ra_ap_paths::{AbsPathBuf, Utf8PathBuf}; use ra_ap_project_model::{CargoConfig, ProjectManifest}; use ra_ap_vfs::Vfs; @@ -114,9 +115,10 @@ impl<'a> Extractor<'a> { &mut self, project: &ProjectManifest, config: &CargoConfig, + load_config: &LoadCargoConfig, ) -> Option<(RootDatabase, Vfs)> { let before = Instant::now(); - let ret = RustAnalyzer::load_workspace(project, config); + let ret = RustAnalyzer::load_workspace(project, config, load_config); self.steps .push(ExtractionStep::load_manifest(before, project)); ret @@ -235,9 +237,12 @@ fn main() -> anyhow::Result<()> { } extractor.extract_without_semantics(file, "no manifest found"); } - let cargo_config = cfg.to_cargo_config(&cwd()?); + let cwd = cwd()?; + let (cargo_config, load_cargo_config) = cfg.to_cargo_config(&cwd); for (manifest, files) in map.values().filter(|(_, files)| !files.is_empty()) { - if let Some((ref db, ref vfs)) = extractor.load_manifest(manifest, &cargo_config) { + if let Some((ref db, ref vfs)) = + extractor.load_manifest(manifest, &cargo_config, &load_cargo_config) + { let semantics = Semantics::new(db); for file in files { match extractor.load_source(file, &semantics, vfs) { diff --git a/rust/extractor/src/rust_analyzer.rs b/rust/extractor/src/rust_analyzer.rs index ac139d68e12b..2ebbcac6b590 100644 --- a/rust/extractor/src/rust_analyzer.rs +++ b/rust/extractor/src/rust_analyzer.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use ra_ap_base_db::SourceDatabase; use ra_ap_hir::Semantics; use ra_ap_ide_db::RootDatabase; -use ra_ap_load_cargo::{load_workspace_at, LoadCargoConfig, ProcMacroServerChoice}; +use ra_ap_load_cargo::{load_workspace_at, LoadCargoConfig}; use ra_ap_paths::{AbsPath, Utf8PathBuf}; use ra_ap_project_model::ProjectManifest; use ra_ap_project_model::{CargoConfig, ManifestPath}; @@ -50,16 +50,12 @@ impl<'a> RustAnalyzer<'a> { pub fn load_workspace( project: &ProjectManifest, config: &CargoConfig, + load_config: &LoadCargoConfig, ) -> Option<(RootDatabase, Vfs)> { let progress = |t| (trace!("progress: {}", t)); - let load_config = LoadCargoConfig { - load_out_dirs_from_check: true, - with_proc_macro_server: ProcMacroServerChoice::Sysroot, - prefill_caches: false, - }; let manifest = project.manifest_path(); - match load_workspace_at(manifest.as_ref(), config, &load_config, &progress) { + match load_workspace_at(manifest.as_ref(), config, load_config, &progress) { Ok((db, vfs, _macro_server)) => Some((db, vfs)), Err(err) => { error!("failed to load workspace for {}: {}", manifest, err);