Fix mlcommons pytorch accuracy check script OOM on Orin AGX 64g (mlcommons#1745)

* Fix accuracy_coco.py OOM error on AGX Orin; reduce memory from 70G to less than 10G (the streaming approach is sketched below the file summary).

* Fix mlcommons pytorch accuracy check script OOM on Orin AGX 64g: bug fixes
nvamberl authored Jun 27, 2024
1 parent c62fc11 commit 763883b
Showing 2 changed files with 105 additions and 33 deletions.
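
The heart of the fix: the old script read the entire mlperf_log_accuracy.json into memory with json.load and kept every decoded image around for a single compute_fid call, while the new compute_accuracy streams results with ijson and folds images into InceptionV3 activations one batch at a time. A minimal sketch of that streaming pattern, assuming the accuracy log is a top-level JSON array of objects carrying qsl_idx and a hex-encoded data field (score_image is a hypothetical stand-in for the per-image CLIP/FID work):

import ijson
import numpy as np
from PIL import Image

def score_image(idx, img):
    # hypothetical placeholder for the CLIP scoring / activation batching done in accuracy_coco.py
    pass

with open("mlperf_log_accuracy.json", "r") as f:
    # ijson yields one array element at a time, so the multi-GB log never
    # has to fit in memory the way json.load() requires
    for item in ijson.items(f, "item"):
        img = np.frombuffer(bytes.fromhex(item["data"]), np.uint8).reshape(1024, 1024, 3)
        score_image(item["qsl_idx"], Image.fromarray(img))
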
136 changes: 104 additions & 32 deletions text_to_image/tools/accuracy_coco.py
@@ -14,7 +14,10 @@
import pandas as pd
import torch
from clip.clip_encoder import CLIPEncoder
from fid.fid_score import compute_fid
from fid.inception import InceptionV3
from fid.fid_score import compute_statistics_of_path, get_activations, calculate_frechet_distance
from tqdm import tqdm
import ijson



@@ -44,15 +47,10 @@ def preprocess_image(img_dir, file_name):

def main():
    args = get_args()
    result_dict = {}

    # Load dataset annotations
    df_captions = pd.read_csv(args.caption_path, sep="\t")

    # Load model outputs
    with open(args.mlperf_accuracy_file, "r") as f:
        results = json.load(f)

    # set device
    device = args.device if torch.cuda.is_available() else "cpu"
    if device == "gpu":
@@ -79,39 +77,113 @@ def main():
            for idx in compliance_images_idx_list:
                caption_file.write(f"{idx} {df_captions.iloc[idx]['caption']}\n")

    # Load torchmetrics modules
    clip = CLIPEncoder(device=device)
    # Compute accuracy
    compute_accuracy(
        args.mlperf_accuracy_file,
        args.output_file,
        device,
        dump_compliance_images,
        compliance_images_idx_list,
        args.compliance_images_path,
        df_captions,
        statistics_path,
    )


def compute_accuracy(
    mlperf_accuracy_file,
    output_file,
    device,
    dump_compliance_images,
    compliance_images_idx_list,
    compliance_images_path,
    df_captions,
    statistics_path,
    batch_size=8,
    inception_dims=2048,
    num_workers=1,
):
    if num_workers is None:
        try:
            num_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity is not available under Windows, use
            # os.cpu_count instead (which may not return the *available* number
            # of CPUs).
            num_cpus = os.cpu_count()

        num_workers = min(num_cpus, 8) if num_cpus is not None else 0
    else:
        num_workers = num_workers

    # Prepare models
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_dims]
    inception_model = InceptionV3([block_idx]).to(device)
    clip_model = CLIPEncoder(device=device)

    clip_scores = []
    seen = set()
    result_list = []
    for j in results:
        idx = j['qsl_idx']
        if idx in seen:
            continue
        seen.add(idx)

        # Load generated image
        generated_img = np.frombuffer(bytes.fromhex(j['data']), np.uint8).reshape(1024, 1024, 3)
        result_list.append(generated_img)
        generated_img = Image.fromarray(generated_img)

        # Dump compliance images
        if dump_compliance_images and idx in compliance_images_idx_list:
            generated_img.save(os.path.join(args.compliance_images_path, f"{idx}.png"))

        # generated_img = torch.Tensor(generated_img).to(torch.uint8).to(device)
        # Load Ground Truth
        caption = df_captions.iloc[idx]["caption"]
        clip_scores.append(
            100 * clip.get_clip_score(caption, generated_img).item()
        )
    fid_score = compute_fid(result_list, statistics_path, device)
    result_batch = []
    result_dict = {}
    activations = np.empty((0, inception_dims))

    # Load model outputs
    with open(mlperf_accuracy_file, "r") as f:
        results = ijson.items(f, "item")

        for j in tqdm(results):
            idx = j['qsl_idx']
            if idx in seen:
                continue
            seen.add(idx)

            # Load generated image
            generated_img = np.frombuffer(bytes.fromhex(j['data']), np.uint8).reshape(1024, 1024, 3)
            generated_img = Image.fromarray(generated_img)

            # Dump compliance images
            if dump_compliance_images and idx in compliance_images_idx_list:
                generated_img.save(os.path.join(compliance_images_path, f"{idx}.png"))

            # Load Ground Truth
            caption = df_captions.iloc[idx]["caption"]
            clip_scores.append(
                100 * clip_model.get_clip_score(caption, generated_img).item()
            )

            result_batch.append(generated_img.convert("RGB"))

            if len(result_batch) == batch_size:
                act = get_activations(result_batch, inception_model, batch_size, inception_dims, device, num_workers)
                activations = np.append(activations, act, axis=0)
                result_batch.clear()

        # Remaining data for last batch
        if len(result_batch) > 0:
            act = get_activations(result_batch, inception_model, len(result_batch), inception_dims, device, num_workers)
            activations = np.append(activations, act, axis=0)

    m1, s1 = compute_statistics_of_path(
        statistics_path,
        inception_model,
        batch_size,
        inception_dims,
        device,
        num_workers,
        None,
        None,
    )

    m2 = np.mean(activations, axis=0)
    s2 = np.cov(activations, rowvar=False)

    fid_score = calculate_frechet_distance(m1, s1, m2, s2)

    result_dict["FID_SCORE"] = fid_score
    result_dict["CLIP_SCORE"] = np.mean(clip_scores)
    print(f"Accuracy Results: {result_dict}")

    with open(args.output_file, "w") as fp:
    with open(output_file, "w") as fp:
        json.dump(result_dict, fp, sort_keys=True, indent=4)

if __name__ == "__main__":
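Why streaming suffices for FID: the Fréchet Inception Distance depends only on the mean and covariance of the InceptionV3 pool features, so the script only needs a growing (N, 2048) activations array rather than N full 1024x1024 images. A rough sketch of what calculate_frechet_distance computes from those statistics (the scipy-based body is an assumption for illustration, not necessarily the repository's exact implementation):

import numpy as np
from scipy import linalg

def frechet_distance(m1, s1, m2, s2):
    # Frechet distance between the Gaussians N(m1, s1) and N(m2, s2):
    # ||m1 - m2||^2 + Tr(s1 + s2 - 2 * sqrt(s1 @ s2))
    diff = m1 - m2
    covmean, _ = linalg.sqrtm(s1 @ s2, disp=False)
    return diff @ diff + np.trace(s1 + s2 - 2.0 * covmean.real)

# m1, s1: precomputed reference statistics loaded from statistics_path
# m2, s2: statistics of the streamed generated-image activations, e.g.
#   m2 = np.mean(activations, axis=0); s2 = np.cov(activations, rowvar=False)
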
2 changes: 1 addition & 1 deletion text_to_image/tools/fid/fid_score.py
@@ -119,7 +119,7 @@ def get_activations(

    start_idx = 0

    for batch in tqdm(dataloader):
    for batch in dataloader:
        batch = batch.to(device)

        with torch.no_grad():
