diff --git a/trestle/tasks/oscal_catalog_to_csv.py b/trestle/tasks/oscal_catalog_to_csv.py index b68ed8457..0003321b0 100644 --- a/trestle/tasks/oscal_catalog_to_csv.py +++ b/trestle/tasks/oscal_catalog_to_csv.py @@ -70,16 +70,11 @@ def convert_control_id(control_id: str) -> str: def convert_smt_id(smt_id: str) -> str: """Convert smt id.""" - parts = smt_id.split('_smt') - seg1 = convert_control_id(parts[0]) - seg2 = '' - if len(parts) == 2: - seg2 = parts[1] - if '.' in seg2: - seg2 = seg2.replace('.', '(') - seg2 = seg2 + ')' - rval = f'{seg1}{seg2}' - return rval + parts = smt_id.split('_') + control_id = convert_control_id(parts[0]) + sub_ids = parts[1].split('.')[1:] + sub_id = ''.join(f'({s})' for s in sub_ids) + return f'{control_id}{sub_id}' class CsvHelper: @@ -118,6 +113,15 @@ def _init_control_parent_map(self, recurse=True) -> None: raise RuntimeError('{parent} duplicate?') self._control_parent_map[parent] = control + def derive_text(self, control: Control, part: Part) -> Optional[str]: + """Derive control text.""" + rval = None + if part.prose: + id_ = self._derive_id(part.id) + text = self._resolve_parms(control, part.prose) + rval = join_str(id_, text) + return rval + def get_parent_control(self, ctl_id: str) -> Control: """Return parent Control of child Control.id, if any.""" return self._control_parent_map.get(ctl_id) @@ -141,16 +145,6 @@ def get_statement_text_for_control(self, control: Control) -> Optional[str]: statement_text = self._withdrawn(control) return statement_text - def get_statement_text_for_part(self, control: Control, part: Part) -> Optional[str]: - """Get statement text for part.""" - statement_text = self._derive_text(control, part) - if part.parts: - for subpart in part.parts: - if '_smt' in subpart.id: - partial_text = self._derive_text(control, subpart) - statement_text = join_str(statement_text, partial_text) - return statement_text - def _withdrawn(self, control: Control) -> Optional[str]: """Check if withdrawn.""" rval = None @@ -192,15 +186,6 @@ def _href_to_control(self, href: str) -> str: rval = href.replace('#', '').upper() return rval - def _derive_text(self, control: Control, part: Part) -> Optional[str]: - """Derive control text.""" - rval = None - if part.prose: - id_ = self._derive_id(part.id) - text = self._resolve_parms(control, part.prose) - rval = join_str(id_, text) - return rval - def _derive_id(self, id_: str) -> str: """Derive control text sub-part id.""" rval = None @@ -319,30 +304,24 @@ def _get_content_by_statement(self) -> List: self.add(row) return self.rows - def _add_subparts_by_statement(self, control: Control, part: Part) -> None: - """Add subparts by statement.""" + def _add_statements_recursively(self, control: Control, part: Part) -> None: + """Add parts and subparts recursively.""" catalog_helper = self.catalog_helper control_id = convert_control_id(control.id) - for subpart in part.parts: - if '_smt' in subpart.id: - statement_text = catalog_helper.get_statement_text_for_part(control, subpart) - row = [control_id, control.title, convert_smt_id(subpart.id), statement_text] - self.add(row) + + if part.id and '_smt' in part.id: + statement_text = catalog_helper.derive_text(control, part) + row = [control_id, control.title, convert_smt_id(part.id), statement_text] + self.add(row) + + if part.parts: + for subpart in part.parts: + self._add_statements_recursively(control, subpart) def _add_parts_by_statement(self, control: Control) -> None: """Add parts by statement.""" - catalog_helper = self.catalog_helper - control_id = convert_control_id(control.id) for part in control.parts: - if part.id: - if '_smt' not in part.id: - continue - if part.parts: - self._add_subparts_by_statement(control, part) - else: - statement_text = catalog_helper.get_statement_text_for_part(control, part) - row = [control_id, control.title, convert_smt_id(part.id), statement_text] - self.add(row) + self._add_statements_recursively(control, part) def _get_content_by_control(self) -> List: """Get content by statement."""