diff --git a/src/DIRAC/Core/Security/ProxyFile.py b/src/DIRAC/Core/Security/ProxyFile.py index 22721934912..cb35144b4ce 100644 --- a/src/DIRAC/Core/Security/ProxyFile.py +++ b/src/DIRAC/Core/Security/ProxyFile.py @@ -6,6 +6,7 @@ from DIRAC import S_OK, S_ERROR from DIRAC.Core.Utilities import DErrno +from DIRAC.Core.Utilities.File import secureOpenForWrite from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.Core.Security.Locations import getProxyLocation @@ -17,22 +18,11 @@ def writeToProxyFile(proxyContents, fileName=False): - proxyContents : string object to dump to file - fileName : filename to dump to """ - if not fileName: - try: - fd, proxyLocation = tempfile.mkstemp() - os.close(fd) - except OSError: - return S_ERROR(DErrno.ECTMPF) - fileName = proxyLocation try: - with open(fileName, "w") as fd: + with secureOpenForWrite(fileName) as fd: fd.write(proxyContents) except Exception as e: return S_ERROR(DErrno.EWF, f" {fileName}: {repr(e).replace(',)', ')')}") - try: - os.chmod(fileName, stat.S_IRUSR | stat.S_IWUSR) - except Exception as e: - return S_ERROR(DErrno.ESPF, f"{fileName}: {repr(e).replace(',)', ')')}") return S_OK(fileName) diff --git a/src/DIRAC/Core/Security/m2crypto/X509CRL.py b/src/DIRAC/Core/Security/m2crypto/X509CRL.py index c1c4e27aefa..14494b19d9d 100644 --- a/src/DIRAC/Core/Security/m2crypto/X509CRL.py +++ b/src/DIRAC/Core/Security/m2crypto/X509CRL.py @@ -1,15 +1,13 @@ """ X509CRL is a class for managing X509CRL This class is used to manage the revoked certificates.... """ -import stat -import os -import tempfile import re import datetime import M2Crypto from DIRAC import S_OK, S_ERROR from DIRAC.Core.Utilities import DErrno +from DIRAC.Core.Utilities.File import secureOpenForWrite # pylint: disable=broad-except @@ -72,17 +70,10 @@ def dumpAllToFile(self, filename=False): if not self.__loadedCert: return S_ERROR("No certificate loaded") try: - if not filename: - fd, filename = tempfile.mkstemp() - os.close(fd) - with open(filename, "w", encoding="ascii") as fd: + with secureOpenForWrite(filename) as fd: fd.write(self.__pemData) except Exception as e: return S_ERROR(DErrno.EWF, f"{filename}: {repr(e).replace(',)', ')')}") - try: - os.chmod(filename, stat.S_IRUSR | stat.S_IWUSR) - except Exception as e: - return S_ERROR(DErrno.ESPF, f"{filename}: {repr(e).replace(',)', ')')}") return S_OK(filename) def hasExpired(self): diff --git a/src/DIRAC/Core/Security/m2crypto/X509Chain.py b/src/DIRAC/Core/Security/m2crypto/X509Chain.py index ea5e77720f2..ea1339925bb 100644 --- a/src/DIRAC/Core/Security/m2crypto/X509Chain.py +++ b/src/DIRAC/Core/Security/m2crypto/X509Chain.py @@ -8,9 +8,6 @@ """ import copy -import os -import stat -import tempfile import hashlib import re @@ -21,6 +18,7 @@ from DIRAC import S_OK, S_ERROR from DIRAC.Core.Utilities import DErrno from DIRAC.Core.Utilities.Decorators import executeOnlyIf, deprecated +from DIRAC.Core.Utilities.File import secureOpenForWrite from DIRAC.ConfigurationSystem.Client.Helpers import Registry from DIRAC.Core.Security.m2crypto import PROXY_OID, LIMITED_PROXY_OID, DIRAC_GROUP_OID, DEFAULT_PROXY_STRENGTH from DIRAC.Core.Security.m2crypto.X509Certificate import X509Certificate @@ -492,14 +490,10 @@ def generateProxyToFile(self, filePath, lifetime, diracGroup=False, strength=DEF if not retVal["OK"]: return retVal try: - with open(filePath, "w") as fd: + with secureOpenForWrite(filePath) as fd: fd.write(retVal["Value"]) except Exception as e: return S_ERROR(DErrno.EWF, f"{filePath} :{repr(e).replace(',)', ')')}") - try: - os.chmod(filePath, stat.S_IRUSR | stat.S_IWUSR) - except Exception as e: - return S_ERROR(DErrno.ESPF, f"{filePath} :{repr(e).replace(',)', ')')}") return S_OK() @needCertList @@ -880,17 +874,10 @@ def dumpAllToFile(self, filename=False): return retVal pemData = retVal["Value"] try: - if not filename: - fd, filename = tempfile.mkstemp() - os.close(fd) - with open(filename, "w") as fp: - fp.write(pemData) + with secureOpenForWrite(filename) as fh: + fh.write(pemData) except Exception as e: return S_ERROR(DErrno.EWF, f"{filename} :{repr(e).replace(',)', ')')}") - try: - os.chmod(filename, stat.S_IRUSR | stat.S_IWUSR) - except Exception as e: - return S_ERROR(DErrno.ESPF, f"{filename} :{repr(e).replace(',)', ')')}") return S_OK(filename) @needCertList diff --git a/src/DIRAC/Core/Utilities/File.py b/src/DIRAC/Core/Utilities/File.py index e01953a4566..f61d75ea63d 100755 --- a/src/DIRAC/Core/Utilities/File.py +++ b/src/DIRAC/Core/Utilities/File.py @@ -11,6 +11,9 @@ import sys import re import errno +import stat +import tempfile +from contextlib import contextmanager # Translation table of a given unit to Bytes # I know, it should be kB... @@ -253,6 +256,27 @@ def convertSizeUnits(size, srcUnit, dstUnit): return -sys.maxsize +@contextmanager +def secureOpenForWrite(filename=None, *, text=True): + """Securely open a file for writing. + + If filename is not provided, a file is created in tempfile.gettempdir(). + The file always created with mode 600. + + :param string filename: name of file to be opened + """ + if filename: + fd = os.open( + path=filename, + flags=os.O_WRONLY | os.O_CREAT | os.O_TRUNC, + mode=stat.S_IRUSR | stat.S_IWUSR, + ) + else: + fd, filename = tempfile.mkstemp(text=text) + with open(fd, "w" if text else "wb", encoding="ascii") as fd: + yield fd + + if __name__ == "__main__": for p in sys.argv[1:]: print(f"{p} : {getGlobbedTotalSize(p)} bytes") diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index c110e53fd35..eacae1cb754 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -1,13 +1,13 @@ import os import re import jwt -import stat import time import json import datetime from DIRAC import S_OK, S_ERROR from DIRAC.Core.Utilities import DErrno +from DIRAC.Core.Utilities.File import secureOpenForWrite from DIRAC.ConfigurationSystem.Client.Helpers import Registry from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory @@ -83,14 +83,10 @@ def writeToTokenFile(tokenContents, fileName): """ location = getTokenFileLocation(fileName) try: - with open(location, "w") as fd: + with secureOpenForWrite(location) as fd: fd.write(tokenContents) except Exception as e: return S_ERROR(DErrno.EWF, f" {location}: {repr(e)}") - try: - os.chmod(location, stat.S_IRUSR | stat.S_IWUSR) - except Exception as e: - return S_ERROR(DErrno.ESPF, f"{location}: {repr(e)}") return S_OK(location) diff --git a/src/DIRAC/Interfaces/Utilities/DConfigCache.py b/src/DIRAC/Interfaces/Utilities/DConfigCache.py index feeca2375d2..3e9cab14213 100644 --- a/src/DIRAC/Interfaces/Utilities/DConfigCache.py +++ b/src/DIRAC/Interfaces/Utilities/DConfigCache.py @@ -1,12 +1,12 @@ #!/usr/bin/env python import os import re -import stat import time import pickle import tempfile from DIRAC.Core.Base.Script import Script +from DIRAC.Core.Utilities.File import secureOpenForWrite from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData @@ -67,8 +67,7 @@ def cacheConfig(self): if self.newConfig: self.__cleanCacheDirectory() - with open(self.configCacheName, "wb") as fcache: - os.chmod(self.configCacheName, stat.S_IRUSR | stat.S_IWUSR) + with secureOpenForWrite(self.configCacheName, text=False) as fcache: pickle.dump(gConfigurationData.mergedCFG, fcache) else: try: diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/PilotWrapper.py b/src/DIRAC/WorkloadManagementSystem/Utilities/PilotWrapper.py index ae55088c3f3..8156c108e4b 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/PilotWrapper.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/PilotWrapper.py @@ -34,6 +34,7 @@ from __future__ import print_function import os +import io import stat import tempfile import sys @@ -130,7 +131,8 @@ def pilotWrapperScript( for pfName, encodedPf in pilotFilesCompressedEncodedDict.items(): compressedString += """ try: - with open('%(pfName)s', 'wb') as fd: + fd = os.open('%(pfName)s', os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRUSR | stat.S_IWUSR) + with io.open(fd, 'wb') as fd: if sys.version_info < (3,): fd.write(bz2.decompress(base64.b64decode(\"\"\"%(encodedPf)s\"\"\"))) else: