Skip to content

Commit

Permalink
Merge pull request #4 from payalcha/remove-get-writer
Browse files Browse the repository at this point in the history
Remove get_writer function and initialize writer variable in calling function
  • Loading branch information
payalcha authored Dec 11, 2024
2 parents f5cd467 + 0fa09df commit 9cd3323
Show file tree
Hide file tree
Showing 16 changed files with 62 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,15 @@
"source": [
"from torch.utils.tensorboard import SummaryWriter\n",
"\n",
"writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)\n",
"writer = None\n",
"\n",
"def get_writer():\n",
" global writer\n",
" if not writer:\n",
" writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)\n",
"\n",
"def write_metric(node_name, task_name, metric_name, metric, round_number):\n",
" get_writer()\n",
" writer.add_scalar(\"{}/{}/{}\".format(node_name, task_name, metric_name),\n",
" metric, round_number)"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
return writer


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer = get_writer()
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@

from tensorflow.summary import SummaryWriter

writer = None

def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
return writer

def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer = get_writer()
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
return writer


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer = get_writer()
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
return writer


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer = get_writer()
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)

0 comments on commit 9cd3323

Please sign in to comment.