-
Notifications
You must be signed in to change notification settings - Fork 0
/
tflite_utils.py
69 lines (55 loc) · 1.82 KB
/
tflite_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import concurrent.futures
import collections
import dataclasses
import hashlib
import itertools
import json
import math
import os
import pathlib
import random
import re
import string
import time
import urllib.request
import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import requests
import tqdm
import tensorflow as tf
def save_tflites(model, exp_name, epoch):
    """Export ``model`` as a SavedModel and as a TFLite flatbuffer.

    Both artifacts are written under ``{exp_name}/{epoch}``: the SavedModel
    goes to the ``HandSeg`` sub-directory and the converted model to
    ``HandSeg.tflite``.

    Args:
        model: Object exposing ``save(path)``; what it writes must be readable
            by ``tf.lite.TFLiteConverter.from_saved_model``.
        exp_name: Experiment directory used as the output root.
        epoch: Epoch identifier, interpolated into the output path.
    """
    out_dir = f'{exp_name}/{epoch}'
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    saved_model_dir = f'{out_dir}/HandSeg'
    model.save(saved_model_dir)

    # Convert the SavedModel that was just written to disk.
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()

    # Save the converted model next to the SavedModel directory.
    with open(f'{out_dir}/HandSeg.tflite', 'wb') as f:
        f.write(tflite_model)
def save_tflites_from_saved_models(saved_model_dir, tflite_path):
    """Convert the SavedModel at ``saved_model_dir`` to TFLite and write it out.

    Args:
        saved_model_dir: Path to the SavedModel directory to convert.
        tflite_path: Destination file for the converted model; opened in
            binary write mode, so an existing file is overwritten.
    """
    # Run the conversion first so the output file is only opened once the
    # converted bytes are available.
    tflite_bytes = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()

    with open(tflite_path, 'wb') as out_file:
        out_file.write(tflite_bytes)
def save_tflites_ckpt(model, ckpt_path, tflite_path):
    """Load weights into ``model``, export a SavedModel, and convert it to TFLite.

    The weights are loaded from ``ckpt_path``, the SavedModel is written to
    ``{ckpt_path}/HandSeg``, and the converted model is written to
    ``tflite_path``. Afterwards ``model.set_training(True)`` is called.

    Args:
        model: Object exposing ``load_weights``, ``save`` and ``set_training``.
        ckpt_path: Passed to ``model.load_weights`` and also used as the
            parent of the exported ``HandSeg`` SavedModel directory.
        tflite_path: Destination file for the converted model.
    """
    model.load_weights(ckpt_path)

    export_dir = f'{ckpt_path}/HandSeg'
    model.save(export_dir)

    # Convert the SavedModel that was just written to disk.
    tflite_bytes = tf.lite.TFLiteConverter.from_saved_model(export_dir).convert()

    with open(tflite_path, 'wb') as out_file:
        out_file.write(tflite_bytes)

    # NOTE(review): an earlier revision called model.set_training(False) and
    # restored via tf.train.Checkpoint before exporting; that path is no
    # longer active, so this call has no matching "off" switch here.
    # Presumably it puts the model back into training mode -- confirm.
    model.set_training(True)