Skip to content

Commit

Permalink
Merge pull request #1048 from bghira/main
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
bghira authored Oct 11, 2024
2 parents 833c7f1 + 086b993 commit 0b426e4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
41 changes: 28 additions & 13 deletions helpers/data_backend/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,34 @@ def __init__(
**extra_args,
)

def exists(self, s3_key) -> bool:
"""Determine whether a file exists in S3."""
try:
# logger.debug(f"Checking if file exists: {s3_key}")
self.client.head_object(Bucket=self.bucket_name, Key=str(s3_key))
return True
# Catch the error when the file does not exist
except (Exception, self.client.exceptions.NoSuchKey) as e:
if "Not Found" not in str(e) and "Bad Request" not in str(e):
raise
return False
except:
return False
def exists(self, s3_key):
"""Check if the key exists in S3, with retries for transient errors."""
for i in range(self.read_retry_limit):
try:
self.client.head_object(Bucket=self.bucket_name, Key=str(s3_key))
return True
except self.client.exceptions.NoSuchKey:
logger.debug(
f"File {s3_key} does not exist in S3 bucket ({self.bucket_name})"
)
return False
except (NoCredentialsError, PartialCredentialsError) as e:
raise e # Raise credential errors to the caller
except Exception as e:
logger.error(f'Error checking existence of S3 key "{s3_key}": {e}')
if i == self.read_retry_limit - 1:
# We have reached our maximum retry count.
raise e
else:
# Sleep for a bit before retrying.
time.sleep(self.read_retry_interval)
except:
if i == self.read_retry_limit - 1:
# We have reached our maximum retry count.
raise
else:
# Sleep for a bit before retrying.
time.sleep(self.read_retry_interval)

def read(self, s3_key):
"""Retrieve and return the content of the file from S3."""
Expand Down
8 changes: 6 additions & 2 deletions helpers/training/default_settings/safety_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ def safety_check(args, accelerator):
args.base_model_precision = "int8-quanto"

if (
(args.base_model_precision in ["fp8-quanto", "int4-quanto"] or (args.base_model_precision != "no_change" and args.quantize_activations))
and (accelerator is not None and accelerator.state.dynamo_plugin.backend.lower() == "inductor")
args.base_model_precision in ["fp8-quanto", "int4-quanto"]
or (args.base_model_precision != "no_change" and args.quantize_activations)
) and (
accelerator is not None
and accelerator.state.dynamo_plugin.backend.lower() == "inductor"
):
logger.warning(
f"{args.base_model_precision} is not supported with Dynamo backend. Disabling Dynamo."
)
from accelerate.utils import DynamoBackend

accelerator.state.dynamo_plugin.backend = DynamoBackend.NO
if args.report_to == "wandb":
if not is_wandb_available():
Expand Down
5 changes: 4 additions & 1 deletion helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,10 @@ def process_prompts(self):
self._log_validations_to_webhook(validation_images, shortname, prompt)
logger.debug(f"Completed generating image: {prompt}")
self.validation_images = validation_images
self._log_validations_to_trackers(validation_images)
try:
self._log_validations_to_trackers(validation_images)
except Exception as e:
logger.error(f"Error logging validation images: {e}")

def stitch_conditioning_images(self, validation_image_results, conditioning_image):
"""
Expand Down

0 comments on commit 0b426e4

Please sign in to comment.