Skip to content

Commit

Permalink
save configs feature #118
Browse files Browse the repository at this point in the history
  • Loading branch information
absadiki committed Mar 4, 2024
1 parent d44a248 commit f3270b7
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/subsai/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import importlib
import json
import mimetypes
import os.path
import shutil
Expand Down Expand Up @@ -305,16 +306,27 @@ def webui() -> None:
info = SubsAI.model_info(stt_model_name)
st.info(info['description'] + '\n' + info['url'])

configs_mode = st.selectbox("Select Configs Mode", ['Manual', 'Load from local file'], index=0,
help='Play manually with the model configs or load them from an exported json file.')

with st.sidebar.expander('Model Configs', expanded=False):
config_schema = SubsAI.config_schema(stt_model_name)
_generate_config_ui(stt_model_name, config_schema)

if configs_mode == 'Manual':
_generate_config_ui(stt_model_name, config_schema)
else:
configs_path = st.text_input('Configs path', help='Absolute path of the configs file')

transcribe_button = st.button('Transcribe', type='primary')
transcribe_loading_placeholder = st.empty()

if transcribe_button:
config_schema = SubsAI.config_schema(stt_model_name)
model_config = _get_config_from_session_state(stt_model_name, config_schema, notification_placeholder)
if configs_mode == 'Manual':
model_config = _get_config_from_session_state(stt_model_name, config_schema, notification_placeholder)
else:
with open(configs_path, 'r', encoding='utf-8') as f:
model_config = json.load(f)
subs = _transcribe(file_path, stt_model_name, model_config)
st.session_state['transcribed_subs'] = subs
transcribe_loading_placeholder.success('Done!', icon="✅")
Expand Down Expand Up @@ -540,6 +552,12 @@ def webui() -> None:
st.error("See the terminal for more info!")
print(e)

with st.expander('Export configs file'):
export_filename = st.text_input('Filename', value=f"{stt_model_name}_configs.json".replace('/', '-'))
configs_dict = _get_config_from_session_state(stt_model_name, config_schema, notification_placeholder)
st.download_button('Download', data=json.dumps(configs_dict), file_name=export_filename, mime='json')


st.markdown(footer, unsafe_allow_html=True)


Expand Down

0 comments on commit f3270b7

Please sign in to comment.