-
Notifications
You must be signed in to change notification settings - Fork 209
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
misc: Add multiple machine example (#757)
This is a more exotic example of multi env usage --------- Co-authored-by: Wackyator <[email protected]>
- Loading branch information
1 parent
4389bfd
commit 251dd0f
Showing
7 changed files
with
8,755 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# GitHub syntax highlighting | ||
pixi.lock linguist-language=YAML |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# pixi environments | ||
.pixi |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
[project] | ||
name = "multi-machine" | ||
description = "A mock project that does ML stuff" | ||
channels = ["conda-forge", "pytorch"] | ||
# All platforms that are supported by the project as the features will take the intersection of the platforms defined there. | ||
platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] | ||
|
||
[tasks] | ||
train = "python train.py" | ||
test = "python test.py" | ||
start = {depends_on = ["train", "test"]} | ||
|
||
[dependencies] | ||
python = "3.11.*" | ||
pytorch = {version = ">=2.0.1", channel = "pytorch"} | ||
torchvision = {version = ">=0.15", channel = "pytorch"} | ||
polars = ">=0.20,<0.21" | ||
matplotlib-base = ">=3.8.2,<3.9" | ||
ipykernel = ">=6.28.0,<6.29" | ||
|
||
[feature.cuda] | ||
platforms = ["win-64", "linux-64"] | ||
channels = ["nvidia", {channel = "pytorch", priority = -1}] | ||
system-requirements = {cuda = "12.1"} | ||
|
||
[feature.cuda.tasks] | ||
train = "python train.py --cuda" | ||
test = "python test.py --cuda" | ||
|
||
[feature.cuda.dependencies] | ||
pytorch-cuda = {version = "12.1.*", channel = "pytorch"} | ||
|
||
[feature.mlx] | ||
platforms = ["osx-arm64"] | ||
system-requirements = {macos = "13.3"} | ||
|
||
[feature.mlx.tasks] | ||
train = "python train.py --mlx" | ||
test = "python test.py --mlx" | ||
|
||
[feature.mlx.dependencies] | ||
mlx = "*" | ||
|
||
|
||
[environments] | ||
cuda = ["cuda"] | ||
mlx = ["mlx"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import os | ||
import sys | ||
|
||
print("Hello from test.py!") | ||
print("Environment you are running on:") | ||
print(os.environ["PIXI_ENVIRONMENT_NAME"]) | ||
print("Arguments given to the script:") | ||
print(sys.argv[1:]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import sys | ||
|
||
if os.environ["PIXI_ENVIRONMENT_NAME"] == "mlx": | ||
import mlx.core as mx | ||
a = mx.array([1, 2, 3, 4]) | ||
print(a.shape) | ||
print("MLX is available, in mlx environment as expected") | ||
|
||
if os.environ["PIXI_ENVIRONMENT_NAME"] == "cuda": | ||
import torch | ||
assert torch.cuda.is_available(), "CUDA is not available" | ||
print("CUDA is available, in cuda environment as expected") | ||
|
||
if os.environ["PIXI_ENVIRONMENT_NAME"] == "default": | ||
import torch | ||
assert not torch.cuda.is_available(), "CUDA is available, in default environment" | ||
print("CUDA is not available, in default environment as expected") | ||
|
||
print("\nHello from train.py!") | ||
print("Environment you are running on:") | ||
print(os.environ["PIXI_ENVIRONMENT_NAME"]) | ||
print("Arguments given to the script:") | ||
print(sys.argv[1:]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters