Skip to content

Commit

Permalink
update verbose style
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Nov 24, 2023
1 parent 8a096dc commit 7390e06
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 145 deletions.
40 changes: 16 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ len of classes: 10
classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dtype of images: torch.float32
dtype of labels: torch.int64
range of image values: min=0.0, max=1.0



Expand Down Expand Up @@ -236,9 +237,9 @@ print(np.array(epoch_losses).round(2))

```
```python
[ 0.69 10.82 0.66 0.68 0.68]
[0.61 0.51 0.43 0.38 0.33]
[0.3 0.27 0.26 0.25 0.25]
[ 0.71 0.68 17.56 15.31 2.18]
[0.37 0.33 0.29 0.26 0.25]
[0.25 0.25 0.25 0.25 0.25]


```
Expand Down Expand Up @@ -296,18 +297,9 @@ print("Epoch_losses", np.array(epoch_losses).round(2))

```
```python
Epoch 1: -0.7875077128410339
Epoch 2: -0.8151381611824036
Epoch 3: -0.9291865229606628
Epoch 4: -1.9994704723358154
Epoch 5: -2.32348895072937
Epoch_losses [-0.79 -0.82 -0.93 -2. -2.32]

0%| | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 1.36it/s]100%|██████████| 1/1 [00:00<00:00, 1.36it/s]
0%| | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 581.17it/s]
0%| | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 1076.84it/s]
0%| | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 1198.37it/s]
0%| | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 1208.38it/s]
Epoch_losses [-154.51 -197.55 -217.29 -282.91 -318.66]

0%| | 0/5 [00:00<?, ?it/s]Loss: -154.51182556: 0%| | 0/5 [00:00<?, ?it/s]Loss: -154.51182556: 20%|██ | 1/5 [00:00<00:03, 1.33it/s]Loss: -197.54919434: 20%|██ | 1/5 [00:00<00:03, 1.33it/s]Loss: -217.29011536: 20%|██ | 1/5 [00:00<00:03, 1.33it/s]Loss: -282.90725708: 20%|██ | 1/5 [00:00<00:03, 1.33it/s]Loss: -318.66284180: 20%|██ | 1/5 [00:00<00:03, 1.33it/s]Loss: -318.66284180: 100%|██████████| 5/5 [00:00<00:00, 6.59it/s]

```

Expand Down Expand Up @@ -356,19 +348,19 @@ print(unraveled_params)
```python
Before
{'0.weight': Parameter containing:
tensor([[-0.1172, 0.1053, -0.3732],
[ 0.1407, 0.4086, 0.4325]], requires_grad=True), '0.bias': Parameter containing:
tensor([ 0.2276, -0.0026], requires_grad=True), '2.weight': Parameter containing:
tensor([[0.5320, 0.3535]], requires_grad=True), '2.bias': Parameter containing:
tensor([-0.7025], requires_grad=True)}
tensor([[-0.1981, 0.0046, 0.1901],
[-0.1083, 0.1330, -0.2079]], requires_grad=True), '0.bias': Parameter containing:
tensor([ 0.4904, -0.2374], requires_grad=True), '2.weight': Parameter containing:
tensor([[-0.5294, 0.1141]], requires_grad=True), '2.bias': Parameter containing:
tensor([-0.2028], requires_grad=True)}

After ravel
tensor([ 0.2276, -0.0026, -0.1172, 0.1053, -0.3732, 0.1407, 0.4086, 0.4325,
-0.7025, 0.5320, 0.3535], grad_fn=<CatBackward0>)
tensor([ 0.4904, -0.2374, -0.1981, 0.0046, 0.1901, -0.1083, 0.1330, -0.2079,
-0.2028, -0.5294, 0.1141], grad_fn=<CatBackward0>)

After unravel
{'0.weight': tensor([[-0.1172, 0.1053, -0.3732],
[ 0.1407, 0.4086, 0.4325]], grad_fn=<ViewBackward0>), '0.bias': tensor([ 0.2276, -0.0026], grad_fn=<ViewBackward0>), '2.weight': tensor([[0.5320, 0.3535]], grad_fn=<ViewBackward0>), '2.bias': tensor([-0.7025], grad_fn=<ViewBackward0>)}
{'0.weight': tensor([[-0.1981, 0.0046, 0.1901],
[-0.1083, 0.1330, -0.2079]], grad_fn=<ViewBackward0>), '0.bias': tensor([ 0.4904, -0.2374], grad_fn=<ViewBackward0>), '2.weight': tensor([[-0.5294, 0.1141]], grad_fn=<ViewBackward0>), '2.bias': tensor([-0.2028], grad_fn=<ViewBackward0>)}


```
Expand Down
10 changes: 6 additions & 4 deletions astra/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ def get_batch():
inner_batch_size = batch_size

iterable = range(0, len(in_or_out), inner_batch_size)
if verbose:
iterable = tqdm(iterable)

if shuffle:
idx = torch.randperm(len(in_or_out))
Expand Down Expand Up @@ -111,7 +109,11 @@ def one_step(batch_input, batch_output):
epoch_losses = []
state_dict_list = []

for _ in range(epochs):
pbar = range(epochs)
if verbose:
pbar = tqdm(pbar)

for _ in pbar:
loss_buffer = []
for batch_input, batch_output in get_batch():
loss = one_step(batch_input, batch_output)
Expand All @@ -125,7 +127,7 @@ def one_step(batch_input, batch_output):
epoch_losses.append(epoch_loss)

if verbose:
print(f"Epoch {len(epoch_losses)}: {epoch_losses[-1]}")
pbar.set_description(f"Loss: {epoch_loss:.8f}")

if return_state_dict:
return (iter_losses, epoch_losses), state_dict_list
Expand Down
259 changes: 142 additions & 117 deletions sandbox/sandbox.ipynb

Large diffs are not rendered by default.

0 comments on commit 7390e06

Please sign in to comment.