Skip to content

Commit

Permalink
Make GPU dependencies optional (#27)
Browse files Browse the repository at this point in the history
* Move GPU imports and make them optional

Signed-off-by: Ayush Dattagupta <[email protected]>

* Move gpu dependencies to a seperate install

Signed-off-by: Ayush Dattagupta <[email protected]>

* Remove unused import

Signed-off-by: Ayush Dattagupta <[email protected]>

* Switch to placeholder import that raises on usage

Signed-off-by: Ayush Dattagupta <[email protected]>

* Remove deprecated utils usage

Signed-off-by: Ayush Dattagupta <[email protected]>

* Add cuML attribution

Signed-off-by: Ayush Dattagupta <[email protected]>

* Safe import tests, improve install instruction, update gha workflow

Signed-off-by: Ayush Dattagupta <[email protected]>

* Fix pytests due to loc bug

Signed-off-by: Ayush Dattagupta <[email protected]>

* update install instructions

Signed-off-by: Ayush Dattagupta <[email protected]>

* Raise on non module-not-found errors, update logging

Signed-off-by: Ayush Dattagupta <[email protected]>

* Update logging to not change root logger

Signed-off-by: Ayush Dattagupta <[email protected]>

---------

Signed-off-by: Ayush Dattagupta <[email protected]>
  • Loading branch information
ayushdg authored Apr 23, 2024
1 parent 9864988 commit 17e0d5f
Show file tree
Hide file tree
Showing 20 changed files with 493 additions and 141 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ jobs:
# Explicitly install cython: https://github.com/VKCOM/YouTokenToMe/issues/94
run: |
pip install wheel cython
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com .
pip install --no-cache-dir .
pip install pytest
- name: Run tests
# TODO: Remove env variable when gpu dependencies are optional
run: |
RAPIDS_NO_INITIALIZE=1 python -m pytest -v --cpu
python -m pytest -v --cpu
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,20 @@ These modules are designed to be flexible and allow for reordering with few exce

## Installation

NeMo Curator currently requires Python 3.10 and a GPU with CUDA 12 or above installed in order to be used.
NeMo Curator currently requires Python 3.10 and the GPU accelerated modules require CUDA 12 or above installed in order to be used.

NeMo Curator can be installed manually by cloning the repository and installing as follows:
NeMo Curator can be installed manually by cloning the repository and installing as follows -

For CPU only modules:
```
pip install .
```
pip install --extra-index-url https://pypi.nvidia.com .

For CPU + CUDA accelerated modules
```
pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"
```

### NeMo Framework Container

NeMo Curator is available in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo). The NeMo Framework Container provides an end-to-end platform for development of custom generative AI models anywhere. The latest release of NeMo Curator comes preinstalled in the container.
Expand Down
6 changes: 1 addition & 5 deletions nemo_curator/datasets/doc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import dask.dataframe as dd
import dask_cudf

from nemo_curator.utils.distributed_utils import read_data, write_to_disk
from nemo_curator.utils.file_utils import get_all_files_paths_under
Expand Down Expand Up @@ -182,10 +181,7 @@ def _read_json_or_parquet(
)
dfs.append(df)

if backend == "cudf":
raw_data = dask_cudf.concat(dfs, ignore_unknown_divisions=True)
else:
raw_data = dd.concat(dfs, ignore_unknown_divisions=True)
raw_data = dd.concat(dfs, ignore_unknown_divisions=True)

elif isinstance(input_files, str):
# Single file
Expand Down
76 changes: 0 additions & 76 deletions nemo_curator/gpu_deduplication/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,84 +13,8 @@
# limitations under the License.

import argparse
import logging
import os
import socket
from contextlib import nullcontext
from time import time

import cudf
from dask_cuda import LocalCUDACluster
from distributed import Client, performance_report


def create_logger(rank, log_file, name="logger", log_level=logging.INFO):
# Create the logger
logger = logging.getLogger(name)
logger.setLevel(log_level)

myhost = socket.gethostname()

extra = {"host": myhost, "rank": rank}
formatter = logging.Formatter(
"%(asctime)s | %(host)s | Rank %(rank)s | %(message)s"
)

# File handler for output
file_handler = logging.FileHandler(log_file, mode="a")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger = logging.LoggerAdapter(logger, extra)

return logger


# TODO: Remove below to use nemo_curator.distributed_utils.get_client
def get_client(args) -> Client:
if args.scheduler_address:
if args.scheduler_file:
raise ValueError(
"Only one of scheduler_address or scheduler_file can be provided"
)
else:
return Client(address=args.scheduler_address, timeout="30s")
elif args.scheduler_file:
return Client(scheduler_file=args.scheduler_file, timeout="30s")
else:
extra_kwargs = (
{
"enable_tcp_over_ucx": True,
"enable_nvlink": True,
"enable_infiniband": False,
"enable_rdmacm": False,
}
if args.nvlink_only and args.protocol == "ucx"
else {}
)

cluster = LocalCUDACluster(
rmm_pool_size=args.rmm_pool_size,
protocol=args.protocol,
rmm_async=True,
**extra_kwargs,
)
return Client(cluster)


def performance_report_if(path=None, report_name="dask-profile.html"):
if path is not None:
return performance_report(os.path.join(path, report_name))
else:
return nullcontext()


# TODO: Remove below to use nemo_curator.distributed_utils._enable_spilling
def enable_spilling():
"""
Enables spilling to host memory for cudf
"""
cudf.set_option("spill", True)


def get_num_workers(client):
"""
Expand Down
7 changes: 6 additions & 1 deletion nemo_curator/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@
# See https://github.com/NVIDIA/NeMo-Curator/issues/31
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

from nemo_curator.utils.import_utils import gpu_only_import_from

from .add_id import AddId
from .exact_dedup import ExactDuplicates
from .filter import Filter, Score, ScoreFilter
from .fuzzy_dedup import LSH, MinHash
from .meta import Sequential
from .modify import Modify
from .task import TaskDecontamination

# GPU packages
LSH = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup", "LSH")
MinHash = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup", "MinHash")

# Pytorch related imports must come after all imports that require cugraph,
# because of context cleanup issues b/w pytorch and cugraph
# See this issue: https://github.com/rapidsai/cugraph/issues/2718
Expand Down
3 changes: 2 additions & 1 deletion nemo_curator/modules/exact_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator.datasets import DocumentDataset
from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if
from nemo_curator.utils.gpu_utils import is_cudf_type


Expand Down
15 changes: 8 additions & 7 deletions nemo_curator/modules/fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
from typing import List, Tuple, Union

import cudf
import cugraph
import cugraph.dask as dcg
import cugraph.dask.comms.comms as Comms
import cupy as cp
import dask_cudf
import numpy as np
from cugraph import MultiGraph
from dask import dataframe as dd
from dask.dataframe.shuffle import shuffle as dd_shuffle
from dask.utils import M
Expand All @@ -39,12 +39,13 @@
filter_text_rows_by_bucket_batch,
merge_left_to_shuffled_right,
)
from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if
from nemo_curator.utils.distributed_utils import get_current_client, get_num_workers
from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import (
convert_str_id_to_int,
int_ids_to_str,
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import (
get_current_client,
get_num_workers,
performance_report_if,
)
from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str
from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
aggregated_anchor_docs_with_bk_read,
get_restart_offsets,
Expand Down Expand Up @@ -1120,7 +1121,7 @@ def _run_connected_components(
df = df[[self.left_id, self.right_id]].astype(np.int64)
df = dask_cudf.concat([df, self_edge_df])

G = cugraph.MultiGraph(directed=False)
G = MultiGraph(directed=False)
G.from_dask_cudf_edgelist(
df, source=self.left_id, destination=self.right_id, renumber=False
)
Expand Down
9 changes: 5 additions & 4 deletions nemo_curator/scripts/compute_minhashes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
from nemo_curator import MinHash
from nemo_curator.datasets import DocumentDataset
from nemo_curator.gpu_deduplication.ioutils import strip_trailing_sep
from nemo_curator.gpu_deduplication.utils import (
create_logger,
parse_nc_args,
from nemo_curator.gpu_deduplication.utils import parse_nc_args
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import (
get_client,
performance_report_if,
read_data,
)
from nemo_curator.utils.distributed_utils import get_client, read_data
from nemo_curator.utils.file_utils import get_all_files_paths_under


Expand Down
7 changes: 4 additions & 3 deletions nemo_curator/scripts/connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import time

from nemo_curator.gpu_deduplication.utils import enable_spilling, parse_nc_args
from nemo_curator.gpu_deduplication.utils import parse_nc_args
from nemo_curator.modules.fuzzy_dedup import ConnectedComponents
from nemo_curator.utils.distributed_utils import get_client

Expand All @@ -32,9 +32,10 @@ def main(args):
st = time.time()
output_path = os.path.join(args.output_dir, "connected_components.parquet")
args.set_torch_to_use_rmm = False
args.enable_spilling = True

client = get_client(args, cluster_type="gpu")
enable_spilling()
client.run(enable_spilling)

components_stage = ConnectedComponents(
cache_dir=args.cache_dir,
jaccard_pairs_path=args.jaccard_pairs_path,
Expand Down
3 changes: 2 additions & 1 deletion nemo_curator/scripts/find_exact_duplicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

from nemo_curator.datasets import DocumentDataset
from nemo_curator.gpu_deduplication.ioutils import strip_trailing_sep
from nemo_curator.gpu_deduplication.utils import create_logger, parse_nc_args
from nemo_curator.gpu_deduplication.utils import parse_nc_args
from nemo_curator.log import create_logger
from nemo_curator.modules import ExactDuplicates
from nemo_curator.utils.distributed_utils import get_client, read_data
from nemo_curator.utils.file_utils import get_all_files_paths_under
Expand Down
8 changes: 4 additions & 4 deletions nemo_curator/scripts/jaccard_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import os
import time

from nemo_curator.gpu_deduplication.utils import enable_spilling, parse_nc_args
from nemo_curator.gpu_deduplication.utils import parse_nc_args
from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity
from nemo_curator.utils.distributed_utils import get_client, get_num_workers


def main(args):
description = """Computes the Jaccard similarity between document pairs
"""Computes the Jaccard similarity between document pairs
from partitioned parquet dataset. Result is a parquet dataset consiting of
document id pair along with their Jaccard similarity score.
"""
Expand All @@ -30,9 +30,9 @@ def main(args):
output_final_results_path = os.path.join(
OUTPUT_PATH, "jaccard_similarity_results.parquet"
)
args.enable_spilling = True
client = get_client(args, "gpu")
enable_spilling()
client.run(enable_spilling)

print(f"Num Workers = {get_num_workers(client)}", flush=True)
print("Connected to dask cluster", flush=True)
print("Running jaccard compute script", flush=True)
Expand Down
9 changes: 3 additions & 6 deletions nemo_curator/scripts/jaccard_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
import os
import time

from nemo_curator.gpu_deduplication.utils import (
get_client,
get_num_workers,
parse_nc_args,
)
from nemo_curator.gpu_deduplication.utils import get_num_workers, parse_nc_args
from nemo_curator.modules.fuzzy_dedup import _Shuffle
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
get_text_ddf_from_json_path_with_blocksize,
)
Expand All @@ -38,7 +35,7 @@ def main(args):
OUTPUT_PATH = args.output_dir
output_shuffled_docs_path = os.path.join(OUTPUT_PATH, "shuffled_docs.parquet")

client = get_client(args)
client = get_client(args, "gpu")
client.run(func)
print(f"Num Workers = {get_num_workers(client)}", flush=True)
print("Connected to dask cluster", flush=True)
Expand Down
9 changes: 3 additions & 6 deletions nemo_curator/scripts/map_buckets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
import os
import time

from nemo_curator.gpu_deduplication.utils import (
get_client,
get_num_workers,
parse_nc_args,
)
from nemo_curator.gpu_deduplication.utils import get_num_workers, parse_nc_args
from nemo_curator.modules.fuzzy_dedup import _MapBuckets
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
get_bucket_ddf_from_parquet_path,
get_text_ddf_from_json_path_with_blocksize,
Expand Down Expand Up @@ -157,7 +154,7 @@ def main(args):
output_anchor_docs_with_bk_path = os.path.join(
OUTPUT_PATH, "anchor_docs_with_bk.parquet"
)
client = get_client(args)
client = get_client(args, "gpu")
print(f"Num Workers = {get_num_workers(client)}", flush=True)
print("Connected to dask cluster", flush=True)
print("Running jaccard map buckets script", flush=True)
Expand Down
3 changes: 2 additions & 1 deletion nemo_curator/scripts/minhash_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from nemo_curator.gpu_deduplication.jaccard_utils.doc_id_mapping import (
convert_str_id_to_int,
)
from nemo_curator.gpu_deduplication.utils import create_logger, parse_nc_args
from nemo_curator.gpu_deduplication.utils import parse_nc_args
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import get_client


Expand Down
Loading

0 comments on commit 17e0d5f

Please sign in to comment.