Skip to content

Commit

Permalink
feat: pass environment variables during pypi resolution and install (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
nichmor authored Feb 23, 2024
1 parent 359f169 commit 37882ae
Show file tree
Hide file tree
Showing 14 changed files with 299 additions and 219 deletions.
1 change: 1 addition & 0 deletions examples/pypi/activate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export MY_SUPER_ENV=hello
1 change: 1 addition & 0 deletions examples/pypi/env_setup.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SET MY_SUPER_ENV=test
347 changes: 179 additions & 168 deletions examples/pypi/pixi.lock

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions examples/pypi/pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ authors = ["Bas Zalmstra <[email protected]>"]
channels = ["conda-forge"]
platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"]


[activation]
scripts = ["activate.sh"]

[target.win-64.activation]
scripts = ["env_setup.bat"]

[tasks]
start = "python pycosat_example.py"
test = "python pycosat_example.py"
Expand All @@ -25,6 +32,7 @@ black = {version = "~=23.10", extras = ["jupyter"]}
pyliblzfse = "*"
pycosat = "*"
plot-antenna = "==1.7"
env_test_package = "==0.0.3"

[system-requirements]
# Tensorflow on macOS arm64 requires macOS 12.0 or higher
Expand Down
19 changes: 3 additions & 16 deletions src/activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use indexmap::IndexMap;
use std::collections::HashMap;

use itertools::Itertools;
use miette::{Context, IntoDiagnostic};
use miette::IntoDiagnostic;
use rattler_conda_types::Platform;
use rattler_shell::{
activation::{ActivationError, ActivationVariables, Activator, PathModificationBehavior},
Expand All @@ -11,7 +11,6 @@ use rattler_shell::{

use crate::{
environment::{get_up_to_date_prefix, LockFileUsage},
progress::await_in_progress,
project::{manifest::EnvironmentName, Environment},
Project,
};
Expand Down Expand Up @@ -180,23 +179,11 @@ pub fn get_environment_variables<'p>(environment: &'p Environment<'p>) -> HashMa
pub async fn get_activation_env<'p>(
environment: &'p Environment<'p>,
lock_file_usage: LockFileUsage,
) -> miette::Result<HashMap<String, String>> {
) -> miette::Result<&HashMap<String, String>> {
// Get the prefix which we can then activate.
get_up_to_date_prefix(environment, lock_file_usage, false, IndexMap::default()).await?;

// Get environment variables from the activation
let activation_env =
await_in_progress("activating environment", |_| run_activation(environment))
.await
.wrap_err("failed to activate environment")?;

let environment_variables = get_environment_variables(environment);

// Construct command environment by concatenating the environments
Ok(activation_env
.into_iter()
.chain(environment_variables.into_iter())
.collect())
environment.project().get_env_variables(environment).await
}

#[cfg(test)]
Expand Down
12 changes: 6 additions & 6 deletions src/cli/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,18 @@ pub async fn execute(args: Args) -> miette::Result<()> {
#[cfg(target_family = "unix")]
let res = match interactive_shell {
ShellEnum::NuShell(nushell) => {
start_nu_shell(nushell, &env, prompt::get_nu_prompt(prompt_name.as_str())).await
start_nu_shell(nushell, env, prompt::get_nu_prompt(prompt_name.as_str())).await
}
ShellEnum::PowerShell(pwsh) => start_powershell(
pwsh,
&env,
env,
prompt::get_powershell_prompt(prompt_name.as_str()),
),
ShellEnum::Bash(bash) => {
start_unix_shell(
bash,
vec!["-l", "-i"],
&env,
env,
prompt::get_bash_prompt(prompt_name.as_str()),
)
.await
Expand All @@ -258,7 +258,7 @@ pub async fn execute(args: Args) -> miette::Result<()> {
start_unix_shell(
zsh,
vec!["-l", "-i"],
&env,
env,
prompt::get_zsh_prompt(prompt_name.as_str()),
)
.await
Expand All @@ -267,13 +267,13 @@ pub async fn execute(args: Args) -> miette::Result<()> {
start_unix_shell(
fish,
vec![],
&env,
env,
prompt::get_fish_prompt(prompt_name.as_str()),
)
.await
}
ShellEnum::Xonsh(xonsh) => {
start_unix_shell(xonsh, vec![], &env, prompt::get_xonsh_prompt()).await
start_unix_shell(xonsh, vec![], env, prompt::get_xonsh_prompt()).await
}
_ => {
miette::bail!("Unsupported shell: {:?}", interactive_shell)
Expand Down
4 changes: 3 additions & 1 deletion src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
consts, install, install_pypi,
lock_file::UpdateLockFileOptions,
prefix::Prefix,
progress::{self},
progress,
project::{
manifest::{EnvironmentName, SystemRequirements},
virtual_packages::verify_current_platform_has_required_virtual_packages,
Expand Down Expand Up @@ -180,6 +180,7 @@ pub async fn update_prefix_pypi(
status: &PythonStatus,
system_requirements: &SystemRequirements,
sdist_resolution: SDistResolution,
env_variables: HashMap<String, String>,
) -> miette::Result<()> {
// Remove python packages from a previous python distribution if the python version changed.
install_pypi::remove_old_python_distributions(prefix, platform, status)?;
Expand All @@ -200,6 +201,7 @@ pub async fn update_prefix_pypi(
status,
system_requirements,
sdist_resolution,
env_variables,
)
},
)
Expand Down
6 changes: 5 additions & 1 deletion src/install_pypi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub async fn update_python_distributions(
status: &PythonStatus,
system_requirements: &SystemRequirements,
sdist_resolution: SDistResolution,
env_variables: HashMap<String, String>,
) -> miette::Result<()> {
let Some(python_info) = status.current_info() else {
// No python interpreter in the environment, so there is nothing to do here.
Expand Down Expand Up @@ -121,6 +122,7 @@ pub async fn update_python_distributions(
compatible_tags,
resolve_options,
python_distributions_to_install.clone(),
env_variables,
);

// Remove python packages that need to be removed
Expand Down Expand Up @@ -240,6 +242,7 @@ fn stream_python_artifacts(
compatible_tags: Arc<WheelTags>,
resolve_options: Arc<ResolveOptions>,
packages_to_download: Vec<&CombinedPypiPackageData>,
env_variables: HashMap<String, String>,
) -> (
impl Stream<Item = miette::Result<(Option<String>, HashSet<Extra>, Wheel)>> + '_,
Option<ProgressBar>,
Expand Down Expand Up @@ -272,6 +275,7 @@ fn stream_python_artifacts(
let compatible_tags = compatible_tags.clone();
let resolve_options = resolve_options.clone();
let package_db = package_db.clone();
let env_variables = env_variables.clone();

async move {
// Determine the filename from the
Expand Down Expand Up @@ -325,7 +329,7 @@ fn stream_python_artifacts(
marker_environment,
Some(compatible_tags),
resolve_options.deref().clone(),
HashMap::default(),
env_variables,
)
.into_diagnostic()
.context("error in construction of WheelBuilder for `pypi-dependencies` installation")?;
Expand Down
5 changes: 4 additions & 1 deletion src/lock_file/pypi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use std::sync::Arc;
use std::{collections::HashMap, vec};

/// Resolve python packages for the specified project.
// TODO(nichita): extract in strunct passed args
#[allow(clippy::too_many_arguments)]
pub async fn resolve_dependencies<'db>(
package_db: Arc<PackageDb>,
dependencies: IndexMap<PackageName, Vec<PyPiRequirement>>,
Expand All @@ -25,6 +27,7 @@ pub async fn resolve_dependencies<'db>(
conda_packages: &[RepoDataRecord],
python_location: Option<&Path>,
sdist_resolution: SDistResolution,
env_variables: HashMap<String, String>,
) -> miette::Result<Vec<PinnedPackage>> {
if dependencies.is_empty() {
return Ok(vec![]);
Expand Down Expand Up @@ -99,7 +102,7 @@ pub async fn resolve_dependencies<'db>(
python_location,
..Default::default()
},
HashMap::default(),
env_variables.clone(),
)
.await
.wrap_err("failed to resolve `pypi-dependencies`, due to underlying error")?;
Expand Down
4 changes: 3 additions & 1 deletion src/lock_file/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rattler_conda_types::{GenericVirtualPackage, MatchSpec, Platform, RepoDataRe
use rattler_lock::{PackageHashes, PypiPackageData, PypiPackageEnvironmentData};
use rattler_solve::{resolvo, SolverImpl};
use rip::{index::PackageDb, resolve::solve_options::SDistResolution};
use std::{path::Path, sync::Arc};
use std::{collections::HashMap, path::Path, sync::Arc};

/// This function takes as input a set of dependencies and system requirements and returns a set of
/// locked packages.
Expand All @@ -28,6 +28,7 @@ pub async fn resolve_pypi(
pb: &ProgressBar,
python_location: Option<&Path>,
sdist_resolution: SDistResolution,
env_variables: HashMap<String, String>,
) -> miette::Result<LockedPypiPackages> {
// Solve python packages
pb.set_message("resolving pypi dependencies");
Expand All @@ -39,6 +40,7 @@ pub async fn resolve_pypi(
locked_conda_records,
python_location,
sdist_resolution,
env_variables,
)
.await?;

Expand Down
20 changes: 16 additions & 4 deletions src/lock_file/update.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{
config, consts, environment,
config, consts,
environment::{
LockFileUsage, PerEnvironmentAndPlatform, PerGroup, PerGroupAndPlatform, PythonStatus,
self, LockFileUsage, PerEnvironmentAndPlatform, PerGroup, PerGroupAndPlatform, PythonStatus,
},
load_lock_file, lock_file,
load_lock_file,
lock_file::{
update, OutdatedEnvironments, PypiPackageIdentifier, PypiRecordsByName,
self, update, OutdatedEnvironments, PypiPackageIdentifier, PypiRecordsByName,
RepoDataRecordsByName,
},
prefix::Prefix,
Expand Down Expand Up @@ -101,6 +101,8 @@ impl<'p> LockFileDerivedData<'p> {
.unwrap_or_default();
let pypi_records = self.pypi_records(environment, platform).unwrap_or_default();

let env_variables = environment.project().get_env_variables(environment).await?;

// Update the prefix with Pypi records
environment::update_prefix_pypi(
environment.name(),
Expand All @@ -112,6 +114,7 @@ impl<'p> LockFileDerivedData<'p> {
&python_status,
&environment.system_requirements(),
SDistResolution::default(),
env_variables.clone(),
)
.await?;

Expand Down Expand Up @@ -750,13 +753,17 @@ pub async fn ensure_up_to_date_lock_file(
.get_conda_prefix(&group)
.expect("prefix should be available now or in the future");

// Get environment variables from the activation
let env_variables = project.get_env_variables(&environment).await?;

// Spawn a task to solve the pypi environment
let pypi_solve_future = spawn_solve_pypi_task(
group.clone(),
platform,
repodata_future,
prefix_future,
SDistResolution::default(),
env_variables,
);

pending_futures.push(pypi_solve_future.boxed_local());
Expand Down Expand Up @@ -1228,6 +1235,7 @@ async fn spawn_solve_pypi_task(
repodata_records: impl Future<Output = Arc<RepoDataRecordsByName>>,
prefix: impl Future<Output = (Prefix, PythonStatus)>,
sdist_resolution: SDistResolution,
env_variables: &HashMap<String, String>,
) -> miette::Result<TaskResult> {
// Get the Pypi dependencies for this environment
let dependencies = environment.pypi_dependencies(Some(platform));
Expand All @@ -1250,6 +1258,9 @@ async fn spawn_solve_pypi_task(
let (repodata_records, (prefix, python_status)) = tokio::join!(repodata_records, prefix);

let environment_name = environment.name().clone();

let envs = env_variables.clone();

let (pypi_packages, duration) = tokio::spawn(
async move {
let pb = SolveProgressBar::new(
Expand All @@ -1274,6 +1285,7 @@ async fn spawn_solve_pypi_task(
.map(|path| prefix.root().join(path))
.as_deref(),
sdist_resolution,
envs,
)
.await?;

Expand Down
8 changes: 8 additions & 0 deletions src/project/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ impl<'p> PartialEq for Environment<'p> {
impl<'p> Eq for Environment<'p> {}

impl<'p> Environment<'p> {
/// Return new instance of Environment
pub fn new(project: &'p Project, environment: &'p manifest::Environment) -> Self {
Self {
project,
environment,
}
}

/// Returns true if this environment is the default environment.
pub fn is_default(&self) -> bool {
self.environment.name == EnvironmentName::Default
Expand Down
Loading

0 comments on commit 37882ae

Please sign in to comment.