Pass CONDA_OVERRIDE_CUDA to with_cuda of conda-lock #721

Merged
merged 2 commits into from
Jan 16, 2024
13 changes: 13 additions & 0 deletions conda-store-server/conda_store_server/action/generate_lockfile.py
@@ -37,6 +37,18 @@ def print_cmd(cmd):
print_cmd(["conda", "config", "--show"])
print_cmd(["conda", "config", "--show-sources"])

# conda-lock ignores variables defined in the specification, so this code
# gets the value of CONDA_OVERRIDE_CUDA and passes it to conda-lock via
# the with_cuda parameter, see:
Comment on lines +40 to +42
Collaborator


Nit, so non-blocking: since we said this is an "interim" solution, it would be best to also add a TODO (or similar) indicating that we should look at a more robust approach.

# https://github.com/conda-incubator/conda-store/issues/719
# https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-virtual.html#overriding-detected-packages
# TODO: Support all variables once upstream fixes are made to conda-lock,
# see the discussion in issue 719.
if specification.variables is not None:
    cuda_version = specification.variables.get("CONDA_OVERRIDE_CUDA")

The CONDA_OVERRIDE_CUDA variable could come in as a string (which is what conda-lock expects) or as a float.

The example from the description of this PR:

...
variables:
  CONDA_OVERRIDE_CUDA: '12.0'

That value comes through as a string and is forwarded to conda-lock as-is, but the example in the original issue has:

...
variables:
  CONDA_OVERRIDE_CUDA: 11.8

That value would be parsed as a float and forwarded to conda-lock as a float.

I don't see a test update associated with this change, so I'm assuming we don't have any coverage here. Could you try this out locally to see whether we need to convert the value as we pull it off the specification?
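
If CONDA_OVERRIDE_CUDA did turn out to arrive as a float, a conversion at the point where it is read could look roughly like the sketch below. `normalize_cuda_override` is a hypothetical helper name, not something in this PR or in conda-lock, and treating booleans as errors and blank strings as unset are assumptions made only for illustration.

```python
from typing import Optional, Union


def normalize_cuda_override(value: Union[str, int, float, None]) -> Optional[str]:
    """Return the CUDA override as a string, or None if it is unset.

    Hypothetical helper: not part of conda-store or conda-lock.
    """
    if value is None:
        return None
    # bool is a subclass of int, so reject it explicitly (assumed policy).
    if isinstance(value, bool):
        raise TypeError("CONDA_OVERRIDE_CUDA must be a version, not a bool")
    if isinstance(value, (int, float)):
        # str(11.8) == "11.8" and str(12.0) == "12.0", so the text survives.
        return str(value)
    if isinstance(value, str):
        # Treat an empty or whitespace-only string as unset (assumed policy).
        return value.strip() or None
    raise TypeError(
        f"Unsupported type for CONDA_OVERRIDE_CUDA: {type(value).__name__}"
    )
```

One caveat with any such conversion: a float cannot keep a trailing zero, so an unquoted `11.10` would already have become `11.1` before the helper ever saw it. Quoting the value in the specification avoids that.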

Contributor Author


I added a print to action_solve_lockfile to verify: the value is always a string, even if you don't quote it. I tried 12, 12.0, and 12.2.

(In fact, the admin UI renders previously unquoted versions in quotes once you submit an env.)

So no action is needed here.


You also asked offline whether using CONDA_OVERRIDE_CUDA on a machine without a GPU allows for using PyTorch with CUDA on a GPU-enabled machine later. I've tried the following:

  • On a machine with an Intel GPU (so no CUDA):
    • added the env from the top comment to conda-store, which has CONDA_OVERRIDE_CUDA: '12.0'
    • waited for it to generate the lockfile (but didn't wait for it to build)
    • downloaded the lockfile
  • On a machine with an NVIDIA GPU (and CUDA configured):
    • created an env from the downloaded lockfile via conda-lock install -n mytest1 --micromamba ~/mytest1.json
    • activated the env
    • confirmed that CUDA works:
>>> import torch
>>> torch.__file__
'/home/nkaretnikov/.conda/envs/mytest1/lib/python3.12/site-packages/torch/__init__.py'
>>> print(torch.version.cuda)
12.0
>>> x = torch.tensor(1, device='cuda')
>>> x.device
device(type='cuda', index=0)
>>> x
tensor(1, device='cuda:0')

Please let me know if you have any other questions!

else:
    cuda_version = None

# CONDA_FLAGS is used by conda-lock in conda_solver.solve_specs_for_arch
try:
    conda_flags_name = "CONDA_FLAGS"
@@ -48,6 +60,7 @@ def print_cmd(cmd):
        platforms=platforms,
        lockfile_path=lockfile_filename,
        conda_exe=conda_command,
        with_cuda=cuda_version,
    )
finally:
    os.environ.pop(conda_flags_name, None)
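
Since the review points out that this change has no test coverage, the lookup could be pulled into a small function so it can be exercised without running conda-lock. This is only a sketch: `get_cuda_override` is a hypothetical name that does not appear in the PR, and `variables` stands in for `specification.variables`, assumed here to be either a plain mapping or None.

```python
from typing import Mapping, Optional


def get_cuda_override(variables: Optional[Mapping[str, str]]) -> Optional[str]:
    """Mirror the lookup in the diff: return CONDA_OVERRIDE_CUDA or None.

    Hypothetical helper; `variables` plays the role of
    `specification.variables` (assumed to be a mapping or None).
    """
    if variables is not None:
        # Mapping.get returns None when the key is absent.
        return variables.get("CONDA_OVERRIDE_CUDA")
    return None


# The three cases the diff has to handle:
with_override = get_cuda_override({"CONDA_OVERRIDE_CUDA": "12.0"})
without_key = get_cuda_override({"SOME_OTHER_VARIABLE": "1"})
no_variables = get_cuda_override(None)
```

The result would then be passed along as `with_cuda=...`, as in the hunk above; when it is None, the behaviour is left to whatever conda-lock does by default.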