Skip to content

Commit

Permalink
Use SocketIO instead of REST for schema diff compare. #4841
Browse files Browse the repository at this point in the history
  • Loading branch information
pravesh-sharma authored Oct 21, 2022
1 parent 0384f55 commit 1647fc5
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 161 deletions.
186 changes: 91 additions & 95 deletions web/pgadmin/tools/schema_diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
from pgadmin.utils.constants import PREF_LABEL_DISPLAY, MIMETYPE_APP_JS,\
ERROR_MSG_TRANS_ID_NOT_FOUND
from sqlalchemy import or_
from pgadmin.authenticate import socket_login_required
from pgadmin import socketio

MODULE_NAME = 'schema_diff'
COMPARE_MSG = gettext("Comparing objects...")
SOCKETIO_NAMESPACE = '/{0}'.format(MODULE_NAME)


class SchemaDiffModule(PgAdminModule):
Expand Down Expand Up @@ -59,9 +62,6 @@ def get_exposed_url_endpoints(self):
'schema_diff.servers',
'schema_diff.databases',
'schema_diff.schemas',
'schema_diff.compare_database',
'schema_diff.compare_schema',
'schema_diff.poll',
'schema_diff.ddl_compare',
'schema_diff.connect_server',
'schema_diff.connect_database',
Expand Down Expand Up @@ -436,39 +436,38 @@ def schemas(sid, did):
return make_json_response(data=res)


@blueprint.route(
'/compare_database/<int:trans_id>/<int:source_sid>/<int:source_did>/'
'<int:target_sid>/<int:target_did>/<int:ignore_owner>/'
'<int:ignore_whitespaces>',
methods=["GET"],
endpoint="compare_database"
)
@login_required
def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
ignore_owner, ignore_whitespaces):
@socketio.on('compare_database', namespace=SOCKETIO_NAMESPACE)
@socket_login_required
def compare_database(params):
"""
This function will compare the two databases.
"""
# Check the pre validation before compare
status, error_msg, diff_model_obj, session_obj = \
compare_pre_validation(trans_id, source_sid, target_sid)
compare_pre_validation(params['trans_id'], params['source_sid'],
params['target_sid'])
if not status:
socketio.emit('compare_database_failed', error_msg,
namespace=SOCKETIO_NAMESPACE, to=request.sid)
return error_msg

comparison_result = []

diff_model_obj.set_comparison_info(COMPARE_MSG, 0)
update_session_diff_transaction(trans_id, session_obj,
socketio.emit('compare_status', {'diff_percentage': 0,
'compare_msg': COMPARE_MSG}, namespace=SOCKETIO_NAMESPACE,
to=request.sid)
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)

try:
ignore_owner = bool(ignore_owner)
ignore_whitespaces = bool(ignore_whitespaces)
ignore_owner = bool(params['ignore_owner'])
ignore_whitespaces = bool(params['ignore_whitespaces'])

# Fetch all the schemas of source and target database
# Compare them and get the status.
schema_result = fetch_compare_schemas(source_sid, source_did,
target_sid, target_did)
schema_result = \
fetch_compare_schemas(params['source_sid'], params['source_did'],
params['target_sid'], params['target_did'])

total_schema = len(schema_result['source_only']) + len(
schema_result['target_only']) + len(
Expand All @@ -483,12 +482,13 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
# Compare Database objects
comparison_schema_result, total_percent = \
compare_database_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
target_sid=target_sid, target_did=target_did,
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
target_sid=params['target_sid'],
target_did=params['target_did'],
diff_model_obj=diff_model_obj, total_percent=total_percent,
node_percent=node_percent,
ignore_owner=ignore_owner,
node_percent=node_percent, ignore_owner=ignore_owner,
ignore_whitespaces=ignore_whitespaces)
comparison_result = \
comparison_result + comparison_schema_result
Expand All @@ -499,10 +499,12 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
for item in schema_result['source_only']:
comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=item['scid'], target_sid=target_sid,
target_did=target_did, target_scid=None,
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=item['scid'],
target_sid=params['target_sid'],
target_did=params['target_did'], target_scid=None,
schema_name=item['schema_name'],
diff_model_obj=diff_model_obj,
total_percent=total_percent,
Expand All @@ -519,10 +521,12 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
for item in schema_result['target_only']:
comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=None, target_sid=target_sid,
target_did=target_did, target_scid=item['scid'],
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=None, target_sid=params['target_sid'],
target_did=params['target_did'],
target_scid=item['scid'],
schema_name=item['schema_name'],
diff_model_obj=diff_model_obj,
total_percent=total_percent,
Expand All @@ -539,10 +543,13 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
for item in schema_result['in_both_database']:
comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=item['src_scid'], target_sid=target_sid,
target_did=target_did, target_scid=item['tar_scid'],
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=item['src_scid'],
target_sid=params['target_sid'],
target_did=params['target_did'],
target_scid=item['tar_scid'],
schema_name=item['schema_name'],
diff_model_obj=diff_model_obj,
total_percent=total_percent,
Expand All @@ -555,54 +562,54 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,

msg = gettext("Successfully compare the specified databases.")
total_percent = 100
diff_model_obj.set_comparison_info(msg, total_percent)
# Update the message and total percentage done in session object
update_session_diff_transaction(trans_id, session_obj, diff_model_obj)
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)

except Exception as e:
app.logger.exception(e)
socketio.emit('compare_database_failed', str(e),
namespace=SOCKETIO_NAMESPACE, to=request.sid)

return make_json_response(data=comparison_result)
socketio.emit('compare_database_success', comparison_result,
namespace=SOCKETIO_NAMESPACE, to=request.sid)


@blueprint.route(
'/compare_schema/<int:trans_id>/<int:source_sid>/<int:source_did>/'
'<int:source_scid>/<int:target_sid>/<int:target_did>/<int:target_scid>/'
'<int:ignore_owner>/<int:ignore_whitespaces>',
methods=["GET"],
endpoint="compare_schema"
)
@login_required
def compare_schema(trans_id, source_sid, source_did, source_scid,
target_sid, target_did, target_scid, ignore_owner,
ignore_whitespaces):
@socketio.on('compare_schema', namespace=SOCKETIO_NAMESPACE)
@socket_login_required
def compare_schema(params):
"""
This function will compare the two schema.
"""
# Check the pre validation before compare
status, error_msg, diff_model_obj, session_obj = \
compare_pre_validation(trans_id, source_sid, target_sid)
compare_pre_validation(params['trans_id'], params['source_sid'],
params['target_sid'])
if not status:
socketio.emit('compare_schema_failed', error_msg,
namespace=SOCKETIO_NAMESPACE, to=request.sid)
return error_msg

comparison_result = []

diff_model_obj.set_comparison_info(COMPARE_MSG, 0)
update_session_diff_transaction(trans_id, session_obj,
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)
try:
ignore_owner = bool(ignore_owner)
ignore_whitespaces = bool(ignore_whitespaces)
ignore_owner = bool(params['ignore_owner'])
ignore_whitespaces = bool(params['ignore_whitespaces'])
all_registered_nodes = SchemaDiffRegistry.get_registered_nodes()
node_percent = round(100 / len(all_registered_nodes))
total_percent = 0

comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=source_scid, target_sid=target_sid,
target_did=target_did, target_scid=target_scid,
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=params['source_scid'],
target_sid=params['target_sid'],
target_did=params['target_did'],
target_scid=params['target_scid'],
schema_name=gettext('Schema Objects'),
diff_model_obj=diff_model_obj,
total_percent=total_percent,
Expand All @@ -615,43 +622,16 @@ def compare_schema(trans_id, source_sid, source_did, source_scid,

msg = gettext("Successfully compare the specified schemas.")
total_percent = 100
diff_model_obj.set_comparison_info(msg, total_percent)
# Update the message and total percentage done in session object
update_session_diff_transaction(trans_id, session_obj, diff_model_obj)
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)

except Exception as e:
app.logger.exception(e)

return make_json_response(data=comparison_result)


@blueprint.route(
'/poll/<int:trans_id>', methods=["GET"], endpoint="poll"
)
@login_required
def poll(trans_id):
"""
This function is used to check the schema comparison is completed or not.
:param trans_id:
:return:
"""

# Check the transaction and connection status
status, error_msg, diff_model_obj, session_obj = \
check_transaction_status(trans_id)

if error_msg == ERROR_MSG_TRANS_ID_NOT_FOUND:
return make_json_response(success=0, errormsg=error_msg, status=404)

msg, diff_percentage = diff_model_obj.get_comparison_info()

if diff_percentage == 100:
diff_model_obj.set_comparison_info(COMPARE_MSG, 0)
update_session_diff_transaction(trans_id, session_obj,
diff_model_obj)

return make_json_response(data={'compare_msg': msg,
'diff_percentage': diff_percentage})
socketio.emit('compare_schema_failed', str(e),
namespace=SOCKETIO_NAMESPACE, to=request.sid)
socketio.emit('compare_schema_success', comparison_result,
namespace=SOCKETIO_NAMESPACE, to=request.sid)


@blueprint.route(
Expand Down Expand Up @@ -781,7 +761,9 @@ def compare_database_objects(**kwargs):
msg = gettext('Comparing {0}'). \
format(gettext(view.blueprint.collection_label))
app.logger.debug(msg)
diff_model_obj.set_comparison_info(msg, total_percent)
socketio.emit('compare_status', {'diff_percentage': total_percent,
'compare_msg': msg}, namespace=SOCKETIO_NAMESPACE,
to=request.sid)
# Update the message and total percentage in session object
update_session_diff_transaction(trans_id, session_obj,
diff_model_obj)
Expand Down Expand Up @@ -843,7 +825,9 @@ def compare_schema_objects(**kwargs):
format(gettext(view.blueprint.collection_label),
gettext(schema_name))
app.logger.debug(msg)
diff_model_obj.set_comparison_info(msg, total_percent)
socketio.emit('compare_status', {'diff_percentage': total_percent,
'compare_msg': msg}, namespace=SOCKETIO_NAMESPACE,
to=request.sid)
# Update the message and total percentage in session object
update_session_diff_transaction(trans_id, session_obj,
diff_model_obj)
Expand Down Expand Up @@ -943,3 +927,15 @@ def compare_pre_validation(trans_id, source_sid, target_sid):
return False, res, None, None

return True, '', diff_model_obj, session_obj


@socketio.on('connect', namespace=SOCKETIO_NAMESPACE)
def connect():
"""
Connect to the server through socket.
:return:
:rtype:
"""
socketio.emit('connected', {'sid': request.sid},
namespace=SOCKETIO_NAMESPACE,
to=request.sid)
19 changes: 0 additions & 19 deletions web/pgadmin/tools/schema_diff/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def __init__(self, **kwargs):
**kwargs : N number of parameters
"""
self._comparison_result = dict()
self._comparison_msg = gettext('Comparision started...')
self._comparison_percentage = 0

def clear_data(self):
"""
Expand All @@ -59,20 +57,3 @@ def get_result(self, node_name=None):
return self._comparison_result[node_name]

return self._comparison_result

def get_comparison_info(self):
"""
This function is used to get the comparison information.
:return:
"""
return self._comparison_msg, self._comparison_percentage

def set_comparison_info(self, msg, percentage):
"""
This function is used to set the comparison information.
:param msg:
:param percentage:
:return:
"""
self._comparison_msg = msg
self._comparison_percentage = percentage
Loading

0 comments on commit 1647fc5

Please sign in to comment.