
Commit

lint
Signed-off-by: Sarah Yurick <[email protected]>
sarahyurick committed Jul 23, 2024
1 parent 28df120 commit 7bd69cb
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions nemo_curator/modules/distributed_data_classifier.py
@@ -23,12 +23,11 @@
 from crossfit import op
 from crossfit.backend.torch.hf.model import HFModel
 from huggingface_hub import PyTorchModelHubMixin
-from transformers import AutoConfig, AutoTokenizer, AutoModel
+from transformers import AutoConfig, AutoModel, AutoTokenizer
 from transformers.models.deberta_v2 import DebertaV2TokenizerFast

 from nemo_curator.datasets import DocumentDataset


 DOMAIN_IDENTIFIER = "nvidia/domain-classifier"

@@ -54,7 +53,9 @@ def __init__(
         super().__init__()
         self.config = config
         if config_path is None:
-            self.config = AutoConfig.from_pretrained(config.model, output_hidden_states=True)
+            self.config = AutoConfig.from_pretrained(
+                config.model, output_hidden_states=True
+            )
         else:
             self.config = torch.load(config_path)

@@ -108,7 +109,9 @@ def __init__(self, config):
         self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

     def _forward(self, batch):
-        features = self.model(batch["input_ids"], batch["attention_mask"]).last_hidden_state
+        features = self.model(
+            batch["input_ids"], batch["attention_mask"]
+        ).last_hidden_state
         dropped = self.dropout(features)
         outputs = self.fc(dropped)
         return torch.softmax(outputs[:, 0, :], dim=1)
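Editorial note, not part of the commit: the reformatted `_forward` above follows the usual sequence-classification pattern, which is to encode the batch, apply dropout, project the hidden states down to the number of labels, and take a softmax over the first ([CLS]) token position. A minimal self-contained sketch of that pattern follows; the class name, the "microsoft/deberta-v3-base" backbone, and the id2label map are placeholders for illustration, not the file's actual config object.

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel


class SketchClassifier(nn.Module):
    # Hypothetical names; the backbone and label map are placeholder choices.
    def __init__(self, model_name="microsoft/deberta-v3-base", id2label=None):
        super().__init__()
        id2label = id2label or {0: "negative", 1: "positive"}
        self.config = AutoConfig.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.model = AutoModel.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.model.config.hidden_size, len(id2label))

    def forward(self, batch):
        # Encode, then classify from the first ([CLS]) token position.
        features = self.model(
            batch["input_ids"], batch["attention_mask"]
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        return torch.softmax(outputs[:, 0, :], dim=1)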
@@ -119,7 +122,7 @@ def forward(self, batch):
                 return self._forward(batch)
         else:
             return self._forward(batch)

     def set_autocast(self, autocast):
         self.autocast = autocast

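The last hunk shows only context lines: the guard that distinguishes the two identical-looking returns (an `if self.autocast:` test and a mixed-precision context around the first return) sits just above the displayed region and is collapsed in this view. As a hedged sketch of how such a toggle is commonly written, not necessarily this file's exact code:

import torch


class AutocastMixinSketch:
    # Method and attribute names mirror the diff; the torch.autocast line is an
    # assumption, since the diff does not show the collapsed lines.
    def __init__(self):
        self.autocast = False

    def _forward(self, batch):
        raise NotImplementedError  # provided by the concrete model

    def forward(self, batch):
        if self.autocast:
            # Run the same computation under mixed precision on GPU.
            with torch.autocast(device_type="cuda"):
                return self._forward(batch)
        else:
            return self._forward(batch)

    def set_autocast(self, autocast):
        self.autocast = autocast

Calling set_autocast(True) then routes forward through the mixed-precision path, while leaving the underlying _forward computation unchanged.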
