diff --git a/hrms/api/roster.py b/hrms/api/roster.py index 5a363068fc..05c8f0ca1b 100644 --- a/hrms/api/roster.py +++ b/hrms/api/roster.py @@ -2,6 +2,8 @@ from frappe import _ from frappe.utils import add_days, date_diff +from erpnext.setup.doctype.employee.employee import get_holiday_list_for_employee + from hrms.hr.doctype.shift_assignment.shift_assignment import ShiftAssignment from hrms.hr.doctype.shift_assignment_tool.shift_assignment_tool import create_shift_assignment @@ -21,7 +23,6 @@ def get_events( events = {} for event in [holidays, leaves, shifts]: - event = group_by_employee(event) for key, value in event.items(): if key in events: events[key].extend(value) @@ -159,30 +160,17 @@ def insert_shift( def get_holidays(month_start: str, month_end: str, employee_filters: dict[str, str]) -> dict[str, list[dict]]: - Employee = frappe.qb.DocType("Employee") - HolidayList = frappe.qb.DocType("Holiday List") - Holiday = frappe.qb.DocType("Holiday") + holidays = {} - query = ( - frappe.qb.select( - Employee.employee, - Holiday.name.as_("holiday"), - Holiday.holiday_date, - Holiday.description, - Holiday.weekly_off, - ) - .from_(Employee) - .join(HolidayList) - .on(Employee.holiday_list == HolidayList.name) - .join(Holiday) - .on(Holiday.parent == HolidayList.name) - .where(Holiday.holiday_date[month_start:month_end]) - ) - - for filter in employee_filters: - query = query.where(Employee[filter] == employee_filters[filter]) + for employee in frappe.get_list("Employee", filters=employee_filters, pluck="name"): + if holiday_list := get_holiday_list_for_employee(employee, raise_exception=False): + holidays[employee] = frappe.get_all( + "Holiday", + filters={"parent": holiday_list, "holiday_date": ["between", [month_start, month_end]]}, + fields=["name as holiday", "holiday_date", "description", "weekly_off"], + ) - return query.run(as_dict=True) + return holidays def get_leaves(month_start: str, month_end: str, employee_filters: dict[str, str]) -> dict[str, list[dict]]: @@ -211,7 +199,7 @@ def get_leaves(month_start: str, month_end: str, employee_filters: dict[str, str for filter in employee_filters: query = query.where(Employee[filter] == employee_filters[filter]) - return query.run(as_dict=True) + return group_by_employee(query.run(as_dict=True)) def get_shifts( @@ -251,7 +239,7 @@ def get_shifts( for filter in shift_filters: query = query.where(ShiftAssignment[filter] == shift_filters[filter]) - return query.run(as_dict=True) + return group_by_employee(query.run(as_dict=True)) def group_by_employee(events: list[dict]) -> dict[str, list[dict]]: