Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: check for correct platform in task env creation #759

Merged
merged 8 commits into from
Feb 13, 2024
20 changes: 20 additions & 0 deletions src/cli/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use std::convert::identity;
use std::str::FromStr;
use std::{collections::HashMap, path::PathBuf, string::String};

use crate::consts;
use clap::Parser;
use dialoguer::theme::ColorfulTheme;
use itertools::Itertools;
use miette::{miette, Context, Diagnostic};
use rattler_conda_types::Platform;

use crate::activation::get_environment_variables;
use crate::environment::verify_prefix_location_unchanged;
use crate::project::errors::UnsupportedPlatformError;
use crate::task::{
AmbiguousTask, ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory,
Expand All @@ -22,6 +24,7 @@ use crate::lock_file::LockFileDerivedData;
use crate::lock_file::UpdateLockFileOptions;
use crate::progress::await_in_progress;
use crate::project::manifest::EnvironmentName;
use crate::project::virtual_packages::verify_current_platform_has_required_virtual_packages;
use crate::project::Environment;
use thiserror::Error;
use tracing::Level;
Expand Down Expand Up @@ -50,6 +53,15 @@ pub async fn execute(args: Args) -> miette::Result<()> {
// Load the project
let project = Project::load_or_else_discover(args.manifest_path.as_deref())?;

// Sanity check of prefix location
verify_prefix_location_unchanged(
project
.default_environment()
.dir()
.join(consts::PREFIX_FILE_NAME)
.as_path(),
)?;
baszalmstra marked this conversation as resolved.
Show resolved Hide resolved

// Extract the passed in environment name.
let explicit_environment = args
.environment
Expand All @@ -62,6 +74,11 @@ pub async fn execute(args: Args) -> miette::Result<()> {
})
.transpose()?;

// Verify that the current platform has the required virtual packages for the environment.
if let Some(ref explicit_environment) = explicit_environment {
verify_current_platform_has_required_virtual_packages(explicit_environment)?;
}

// Ensure that the lock-file is up-to-date.
let mut lock_file = project
.up_to_date_lock_file(UpdateLockFileOptions {
Expand Down Expand Up @@ -92,6 +109,8 @@ pub async fn execute(args: Args) -> miette::Result<()> {

let task_graph = TaskGraph::from_cmd_args(&project, &search_environment, task_args)?;

tracing::info!("Task graph: {}", task_graph);

// Traverse the task graph in topological order and execute each individual task.
let mut task_idx = 0;
let mut task_envs = HashMap::new();
Expand Down Expand Up @@ -168,6 +187,7 @@ fn command_not_found<'p>(project: &'p Project, explicit_environment: Option<Envi
project
.environments()
.into_iter()
.filter(|env| verify_current_platform_has_required_virtual_packages(env).is_ok())
ruben-arts marked this conversation as resolved.
Show resolved Hide resolved
.flat_map(|env| {
env.tasks(Some(Platform::current()), true)
.into_iter()
Expand Down
11 changes: 11 additions & 0 deletions src/project/virtual_packages.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::manifest::{LibCSystemRequirement, SystemRequirements};
use crate::project::Environment;
use itertools::Itertools;
use miette::IntoDiagnostic;
use rattler_conda_types::{GenericVirtualPackage, Platform, Version};
use rattler_virtual_packages::{Archspec, Cuda, LibC, Linux, Osx, VirtualPackage};
Expand Down Expand Up @@ -104,6 +105,16 @@ pub fn verify_current_platform_has_required_virtual_packages(
) -> miette::Result<()> {
let current_platform = Platform::current();

// Is the current platform in the list of supported platforms?
if !environment.platforms().contains(&current_platform) {
return Err(miette::miette!(
"The current platform '{}' is not supported by the `{}` environment. Supported platforms: {}",
current_platform,
environment.name(),
environment.platforms().iter().map(|plat| plat.as_str()).join(", ")
));
}

let system_virtual_packages = VirtualPackage::current()
.into_diagnostic()?
.iter()
Expand Down
15 changes: 10 additions & 5 deletions src/task/task_environment.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::project::virtual_packages::verify_current_platform_has_required_virtual_packages;
use crate::project::Environment;
use crate::task::error::{AmbiguousTaskError, MissingTaskError};
use crate::task::TaskName;
Expand Down Expand Up @@ -128,6 +129,10 @@ impl<'p, D: TaskDisambiguation<'p>> SearchEnvironments<'p, D> {
.iter()
// Filter out default environment
.filter(|env| !env.name().is_default())
// Filter out environments that can not run on this machine.
.filter(|env| {
verify_current_platform_has_required_virtual_packages(env).is_ok()
})
.any(|env| {
if let Ok(task) = env.task(&name, self.platform) {
// If the task exists in the environment but it is not the reference to the same task, return true to make it ambiguous
Expand Down Expand Up @@ -202,7 +207,7 @@ mod tests {
[project]
name = "foo"
channels = ["foo"]
platforms = ["linux-64"]
platforms = ["linux-64", "osx-arm64", "win-64", "osx-64"]

[tasks]
test = "cargo test"
Expand All @@ -224,7 +229,7 @@ mod tests {
[project]
name = "foo"
channels = ["foo"]
platforms = ["linux-64"]
platforms = ["linux-64", "osx-arm64", "win-64", "osx-64"]

[tasks]
test = "cargo test"
Expand All @@ -247,7 +252,7 @@ mod tests {
[project]
name = "foo"
channels = ["foo"]
platforms = ["linux-64"]
platforms = ["linux-64", "osx-arm64", "win-64", "osx-64"]

[tasks]
test = "pytest"
Expand Down Expand Up @@ -279,7 +284,7 @@ mod tests {
[project]
name = "foo"
channels = ["foo"]
platforms = ["linux-64"]
platforms = ["linux-64", "osx-arm64", "win-64", "osx-64"]

[tasks]

Expand Down Expand Up @@ -314,7 +319,7 @@ mod tests {
[project]
name = "foo"
channels = ["foo"]
platforms = ["linux-64"]
platforms = ["linux-64", "osx-arm64", "win-64", "osx-64"]

[tasks]
bla = "echo foo"
Expand Down
32 changes: 31 additions & 1 deletion src/task/task_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use crate::{
task::{error::MissingTaskError, CmdArgs, Custom, Task},
Project,
};
use itertools::Itertools;
use miette::Diagnostic;
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
env,
env, fmt,
ops::Index,
};
use thiserror::Error;
Expand All @@ -22,6 +23,7 @@ use thiserror::Error;
pub struct TaskId(usize);

/// A node in the [`TaskGraph`].
#[derive(Debug)]
pub struct TaskNode<'p> {
/// The name of the task or `None` if the task is a custom task.
pub name: Option<TaskName>,
Expand All @@ -38,6 +40,23 @@ pub struct TaskNode<'p> {
/// The id's of the task that this task depends on.
pub dependencies: Vec<TaskId>,
}
impl fmt::Display for TaskNode<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"task: {}, environment: {}, command: `{}`, additional arguments: `{}`, depends_on: `{}`",
self.name.clone().unwrap_or("CUSTOM COMMAND".into()).0,
self.run_environment.name(),
self.task.as_single_command().unwrap_or(Cow::Owned("".to_string())),
self.additional_args.join(", "),
self.dependencies
.iter()
.map(|id| id.0.to_string())
.collect::<Vec<String>>()
.join(", ")
)
}
}

impl<'p> TaskNode<'p> {
/// Returns the full command that should be executed for this task. This includes any
Expand All @@ -59,13 +78,24 @@ impl<'p> TaskNode<'p> {

/// A [`TaskGraph`] is a graph of tasks that defines the relationships between different executable
/// tasks.
#[derive(Debug)]
pub struct TaskGraph<'p> {
/// The project that this graph references
project: &'p Project,

/// The tasks in the graph
nodes: Vec<TaskNode<'p>>,
}
impl fmt::Display for TaskGraph<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"TaskGraph: number of nodes: {}, nodes: {}",
self.nodes.len(),
self.nodes.iter().format("\n")
)
}
}

impl<'p> Index<TaskId> for TaskGraph<'p> {
type Output = TaskNode<'p>;
Expand Down
Loading