Skip to content

Commit

Permalink
test: running cuda env and using those tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
ruben-arts committed Feb 2, 2024
1 parent 7599763 commit fa04cca
Showing 1 changed file with 41 additions and 3 deletions.
44 changes: 41 additions & 3 deletions src/task/task_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,22 +366,23 @@ pub enum TaskGraphError {
#[cfg(test)]
mod test {
use crate::task::task_graph::TaskGraph;
use crate::Project;
use crate::{EnvironmentName, Project};
use rattler_conda_types::Platform;
use std::path::Path;

fn commands_in_order(
project_str: &str,
run_args: &[&str],
platform: Option<Platform>,
environment_name: Option<EnvironmentName>,
) -> Vec<String> {
let project = Project::from_str(Path::new(""), project_str).unwrap();

let environment = environment_name.map(|name| project.environment(&name).unwrap());
let graph = TaskGraph::from_cmd_args(
&project,
run_args.into_iter().map(|arg| arg.to_string()).collect(),
platform,
None,
environment,
)
.unwrap();

Expand Down Expand Up @@ -409,6 +410,7 @@ mod test {
top = {cmd="echo top", depends_on=["task1","task2"]}
"#,
&["top", "--test"],
None,
None
),
vec!["echo root", "echo task1", "echo task2", "echo top --test"]
Expand All @@ -431,6 +433,7 @@ mod test {
top = {cmd="echo top", depends_on=["task1","task2"]}
"#,
&["top"],
None,
None
),
vec!["echo root", "echo task1", "echo task2", "echo top"]
Expand All @@ -456,6 +459,7 @@ mod test {
"#,
&["top"],
Some(Platform::Linux64),
None
),
vec!["echo linux", "echo task1", "echo task2", "echo top",]
);
Expand All @@ -473,6 +477,7 @@ mod test {
"#,
&["echo bla"],
None,
None
),
vec![r#""echo bla""#]
);
Expand All @@ -496,6 +501,7 @@ mod test {
"#,
&["build"],
None,
None
),
vec![r#"echo build"#]
);
Expand All @@ -522,8 +528,40 @@ mod test {
"#,
&["start"],
None,
None
),
vec![r#"hello world"#]
);
}

#[test]
fn test_multi_env_cuda() {
assert_eq!(
commands_in_order(
r#"
[project]
name = "pixi"
channels = ["conda-forge"]
platforms = ["linux-64"]
[tasks]
train = "python train.py"
test = "python test.py"
start = {depends_on = ["train", "test"]}
[feature.cuda.tasks]
train = "python train.py --cuda"
test = "python test.py --cuda"
[environments]
cuda = ["cuda"]
"#,
&["start"],
None,
Some(EnvironmentName::Named("cuda".to_string()))
),
vec![r#"python train.py --cuda"#, r#"python test.py --cuda"#]
);
}
}

0 comments on commit fa04cca

Please sign in to comment.