Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CPU support to one_file_ref.py (the one file implementation) #129

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions one_file_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ def __init__(self, args: ModelArgs):
args.sliding_window,
self.n_kv_heads,
self.args.head_dim,
), dtype=torch.float16
).cuda()
)
)
self.cache_v = torch.empty(
(
args.max_batch_size,
args.sliding_window,
self.n_kv_heads,
self.args.head_dim,
), dtype=torch.float16
).cuda()
)
)

def forward(
self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor]
Expand Down Expand Up @@ -236,7 +236,7 @@ def __init__(self, args: ModelArgs):
bias=False
)

self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000).to("cuda")
self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000).to(args.device)


def forward(
Expand Down Expand Up @@ -267,11 +267,16 @@ def forward(
return self.output(self.norm(h)).float()

@staticmethod
def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16):
def from_folder(folder: Path, max_batch_size: int = 1, device: str="cuda"):
with open(folder / 'params.json', 'r') as f:
model_args = ModelArgs(**json.loads(f.read()))
if device == "cuda":
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
torch.set_default_tensor_type(torch.BFloat16Tensor)
model_args.max_batch_size = max_batch_size
model = Transformer(model_args).to(device=device, dtype=dtype)
model_args.device = device
model = Transformer(model_args).to(device=device)
loaded = torch.load(folder / 'consolidated.00.pth')
model.load_state_dict(loaded)
return model
Expand Down Expand Up @@ -300,18 +305,19 @@ def decode(self, t: List[int]) -> str:

@torch.no_grad()
def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
device = model.args.device
encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
prompt_lens = [len(x) for x in encoded_prompts]
min_prompt_len = min(prompt_lens)
max_prompt_len = max(prompt_lens)

input_tokens = torch.full((len(prompts), max_prompt_len), tokenizer.pad_id, dtype=torch.long, device="cuda")
input_tokens = torch.full((len(prompts), max_prompt_len), tokenizer.pad_id, dtype=torch.long, device=device)
for i, encoded in enumerate(encoded_prompts):
input_tokens[i, :len(encoded)] = torch.tensor(encoded).to(input_tokens)
input_mask = input_tokens != tokenizer.pad_id

# pre-fill
positions = torch.arange(0, min_prompt_len).to("cuda")
positions = torch.arange(0, min_prompt_len).to(device)
logits = model.forward(input_tokens[:, :min_prompt_len], positions)
logprobs = nn.functional.log_softmax(logits, dim=-1)

Expand Down Expand Up @@ -345,7 +351,8 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_t

def demo(model_path: str, max_tokens: int = 35):
tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
transformer = Transformer.from_folder(Path(model_path), max_batch_size=3)
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer = Transformer.from_folder(Path(model_path), max_batch_size = 3, device = device)

res, _logprobs = generate(
[
Expand All @@ -362,4 +369,4 @@ def demo(model_path: str, max_tokens: int = 35):
print("=====================")

if __name__ == "__main__":
fire.Fire(demo)
fire.Fire(demo)