diff --git a/one_file_ref.py b/one_file_ref.py
index 61364af..b9c5824 100644
--- a/one_file_ref.py
+++ b/one_file_ref.py
@@ -96,16 +96,16 @@ def __init__(self, args: ModelArgs):
                 args.sliding_window,
                 self.n_kv_heads,
                 self.args.head_dim,
-            ), dtype=torch.float16
-        ).cuda()
+            )
+        )
         self.cache_v = torch.empty(
             (
                 args.max_batch_size,
                 args.sliding_window,
                 self.n_kv_heads,
                 self.args.head_dim,
-            ), dtype=torch.float16
-        ).cuda()
+            )
+        )
 
     def forward(
         self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor]
@@ -236,7 +236,7 @@ def __init__(self, args: ModelArgs):
             bias=False
         )
 
-        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000).to("cuda")
+        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000).to(args.device)
 
 
     def forward(
@@ -267,11 +267,16 @@ def forward(
         return self.output(self.norm(h)).float()
 
     @staticmethod
-    def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16):
+    def from_folder(folder: Path, max_batch_size: int = 1, device: str="cuda"):
         with open(folder / 'params.json', 'r') as f:
             model_args = ModelArgs(**json.loads(f.read()))
+        if device == "cuda":
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        else:
+            torch.set_default_tensor_type(torch.BFloat16Tensor)
         model_args.max_batch_size = max_batch_size
-        model = Transformer(model_args).to(device=device, dtype=dtype)
+        model_args.device = device
+        model = Transformer(model_args).to(device=device)
         loaded = torch.load(folder / 'consolidated.00.pth')
         model.load_state_dict(loaded)
         return model
@@ -300,18 +305,19 @@ def decode(self, t: List[int]) -> str:
 
 @torch.no_grad()
 def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
+    device = model.args.device
     encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
     prompt_lens = [len(x) for x in encoded_prompts]
     min_prompt_len = min(prompt_lens)
    max_prompt_len = max(prompt_lens)
-
-    input_tokens = torch.full((len(prompts), max_prompt_len), tokenizer.pad_id, dtype=torch.long, device="cuda")
+
+    input_tokens = torch.full((len(prompts), max_prompt_len), tokenizer.pad_id, dtype=torch.long, device=device)
     for i, encoded in enumerate(encoded_prompts):
         input_tokens[i, :len(encoded)] = torch.tensor(encoded).to(input_tokens)
     input_mask = input_tokens != tokenizer.pad_id
 
     # pre-fill
-    positions = torch.arange(0, min_prompt_len).to("cuda")
+    positions = torch.arange(0, min_prompt_len).to(device)
     logits = model.forward(input_tokens[:, :min_prompt_len], positions)
     logprobs = nn.functional.log_softmax(logits, dim=-1)
 
@@ -345,7 +351,8 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_t
 
 def demo(model_path: str, max_tokens: int = 35):
     tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
-    transformer = Transformer.from_folder(Path(model_path), max_batch_size=3)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    transformer = Transformer.from_folder(Path(model_path), max_batch_size = 3, device = device)
 
     res, _logprobs = generate(
         [
@@ -362,4 +369,4 @@ def demo(model_path: str, max_tokens: int = 35):
     print("=====================")
 
 if __name__ == "__main__":
-    fire.Fire(demo)
\ No newline at end of file
+    fire.Fire(demo)
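
For reference, a minimal sketch of driving the patched entry points once this diff is applied. It assumes the patched one_file_ref.py is importable from the working directory; the weights folder path and the prompt are hypothetical, while Tokenizer, Transformer.from_folder, and generate are the functions touched above.

    # Sketch only: assumes the patched one_file_ref.py is importable and a
    # Mistral weights folder exists at the hypothetical path below.
    from pathlib import Path

    import torch
    from one_file_ref import Tokenizer, Transformer, generate

    # Same device selection the patched demo() uses.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_path = Path("/path/to/mistral-7B-v0.1")  # hypothetical weights folder
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))

    # from_folder() now sets the default tensor type (fp16 on CUDA, bf16 otherwise)
    # and records the device on ModelArgs before moving the model to it.
    model = Transformer.from_folder(model_path, max_batch_size=1, device=device)

    res, _logprobs = generate(["An example prompt"], model, tokenizer, max_tokens=35)
    print(res[0])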