Skip to content

Commit

Permalink
with black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
StephanHaa9 committed Dec 13, 2024
1 parent 11b4166 commit 683b1da
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions znnl/training_recording/papyrus_jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,15 @@
-------
"""

from typing import List
from typing import List, Union

import numpy as np
from papyrus.measurements import BaseMeasurement
from papyrus.recorders import BaseRecorder

from znnl.analysis import JAXNTKComputation
from znnl.models import JaxModel

from typing import List, Union
import numpy as np


class JaxRecorder(BaseRecorder):
"""
Expand Down Expand Up @@ -76,7 +74,7 @@ class JaxRecorder(BaseRecorder):
An NTK computation module. For more information see the JAXNTKComputation
class.
"""

def __init__(
self,
name: str,
Expand Down Expand Up @@ -223,7 +221,6 @@ def _compute_neural_state(self, model: JaxModel):
predictions = model(self._data_set[list(self._data_set.keys())[0]])
self.neural_state["predictions"] = [predictions]


def record(self, epoch: int, model: JaxModel, **kwargs):
"""
Perform the recording of a neural state.
Expand Down Expand Up @@ -251,32 +248,30 @@ def record(self, epoch: int, model: JaxModel, **kwargs):
do_record = False
# if List[int]
if isinstance(self.recording_schedule, list):
# Check if the current epoch is in the schedule list
do_record = np.isin(epoch, self.recording_schedule)
# Check if the current epoch is in the schedule list
do_record = np.isin(epoch, self.recording_schedule)
# if int
elif isinstance(self.recording_schedule, int):
# Check if the current epoch is a multiple of the recording schedule
do_record = (epoch % self.recording_schedule == 0)
# Check if the current epoch is a multiple of the recording schedule
do_record = epoch % self.recording_schedule == 0
else:
raise ValueError(
raise ValueError(
f"Invalid type for recording_schedule: {type(self.recording_schedule)}. "
"Expected int or list of int."
)
)

# Perform recording if do_record is True
if do_record:
print(f"Recording at epoch {epoch}")
# Compute the neural state
self._compute_neural_state(model)
# Add all other kwargs to the neural state dictionary
self.neural_state.update(kwargs)
for key, val in self._data_set.items():
self.neural_state[key] = [val]
# Check if incoming data is complete
self._check_keys()
# Perform measurements
self._measure(**self.neural_state)
# Store the measurements
self.store(ignore_chunk_size=False)


print(f"Recording at epoch {epoch}")
# Compute the neural state
self._compute_neural_state(model)
# Add all other kwargs to the neural state dictionary
self.neural_state.update(kwargs)
for key, val in self._data_set.items():
self.neural_state[key] = [val]
# Check if incoming data is complete
self._check_keys()
# Perform measurements
self._measure(**self.neural_state)
# Store the measurements
self.store(ignore_chunk_size=False)

0 comments on commit 683b1da

Please sign in to comment.