
Commit

Merge branch 'master' into 6_fix_wd_config
tjruwase authored Oct 9, 2023
2 parents 5bba361 + 0855679 commit f327d6a
Showing 2 changed files with 10 additions and 7 deletions.
@@ -1,8 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 # DeepSpeed Team
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from torch.utils.data import Subset
 import re
 
@@ -15,7 +16,9 @@ def __init__(self, output_path, seed, local_rank, dataset_name):
         self.output_path = output_path
         self.seed = seed
         self.local_rank = local_rank
-        if not dataset_name == 'local/jsonfile':
+        if os.path.exists(dataset_name):
+            self.raw_datasets = load_from_disk(dataset_name)
+        elif not dataset_name == 'local/jsonfile':
             self.raw_datasets = load_dataset(dataset_name)
 
     def get_train_data(self):
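The change above lets the dataset loader accept a local directory produced by datasets' save_to_disk in addition to a Hugging Face Hub name. A minimal sketch of how the new branch can be exercised (not part of this commit; the standalone helper, the Hub dataset name "Dahoas/rm-static", and the local path are illustrative assumptions, since in the repo this logic lives in the dataset class constructor shown in the diff):

import os
from datasets import load_dataset, load_from_disk

def load_raw_dataset(dataset_name):
    # Mirrors the patched constructor: prefer an on-disk copy when the given
    # name is an existing path, otherwise fall back to the Hugging Face Hub.
    if os.path.exists(dataset_name):
        return load_from_disk(dataset_name)
    elif not dataset_name == 'local/jsonfile':
        return load_dataset(dataset_name)

# Illustrative usage: cache a Hub dataset locally once, then load it by path.
local_copy = "/tmp/rm-static-local"          # hypothetical directory
if not os.path.exists(local_copy):
    load_dataset("Dahoas/rm-static").save_to_disk(local_copy)
raw_datasets = load_raw_dataset(local_copy)  # takes the load_from_disk branch
print(raw_datasets)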
applications/DeepSpeed-Chat/training/utils/model/reward_model.py: 10 changes (5 additions & 5 deletions)
@@ -29,15 +29,15 @@ def __init__(self,
         self.config.n_embd = self.config.hidden_size if hasattr(
             self.config, "hidden_size") else self.config.n_embd
         self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
-        self.rwtranrsformer = base_model
+        self.rwtransformer = base_model
         self.PAD_ID = tokenizer.pad_token_id
         self.compute_fp32_loss = compute_fp32_loss
 
     def gradient_checkpointing_enable(self):
-        self.rwtranrsformer.gradient_checkpointing_enable()
+        self.rwtransformer.gradient_checkpointing_enable()
 
     def gradient_checkpointing_disable(self):
-        self.rwtranrsformer.gradient_checkpointing_disable()
+        self.rwtransformer.gradient_checkpointing_disable()
 
     def forward(self,
                 input_ids=None,
@@ -54,7 +54,7 @@ def forward(self,
         else:
             kwargs = dict(head_mask=head_mask)
 
-        transformer_outputs = self.rwtranrsformer(
+        transformer_outputs = self.rwtransformer(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -140,7 +140,7 @@ def forward_value(self,
         else:
             kwargs = dict(head_mask=head_mask)
 
-        transformer_outputs = self.rwtranrsformer(
+        transformer_outputs = self.rwtransformer(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
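The reward_model.py changes only correct the misspelled attribute rwtranrsformer to rwtransformer, so every delegation to the wrapped base model now goes through the corrected name. A minimal usage sketch under assumptions (a generic Hugging Face base model and tokenizer, and a RewardModel constructor that takes the base model and tokenizer first, as the assignments in the diff suggest; the checkpoint name is illustrative):

from transformers import AutoModel, AutoTokenizer
from utils.model.reward_model import RewardModel  # module under DeepSpeed-Chat's training/ tree

base_model = AutoModel.from_pretrained("facebook/opt-125m")  # illustrative base model
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = tokenizer.eos_token                     # ensure pad_token_id is set

rm = RewardModel(base_model, tokenizer)  # hypothetical call; optional args left at defaults
assert rm.rwtransformer is base_model    # the renamed attribute holds the wrapped base model
rm.gradient_checkpointing_enable()       # now delegates through rwtransformer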
