Skip to content

Commit

Permalink
Feat recognize ramdisk (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
RangerCD authored Nov 18, 2020
1 parent c99dcf2 commit 7b59b1f
Showing 1 changed file with 84 additions and 55 deletions.
139 changes: 84 additions & 55 deletions mounts_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,41 +58,39 @@ func getClusterInfo(guidOrMountPointBuf []uint16) (totalClusters uint32, cluster
return
}

func getMountFromGUID(guidBuf []uint16) (m Mount, skip bool, warnings []string) {
func getMount(guidOrMountPointBuf []uint16, isGUID bool) (m Mount, skip bool, warnings []string) {
var err error
guid := windows.UTF16ToString(guidBuf)
guidOrMountPoint := windows.UTF16ToString(guidOrMountPointBuf)

mountPoint, err := getMountPoint(guidBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", guid, err))
}
// Skip unmounted volumes
if len(mountPoint) == 0 {
skip = true
return
mountPoint := guidOrMountPoint
if isGUID {
mountPoint, err = getMountPoint(guidOrMountPointBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", guidOrMountPoint, err))
}
// Skip unmounted volumes
if len(mountPoint) == 0 {
skip = true
return
}
}

// Get volume name & filesystem type
volumeName, fsType, err := getVolumeInfo(guidBuf)
volumeName, fsType, err := getVolumeInfo(guidOrMountPointBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", guid, err))
warnings = append(warnings, fmt.Sprintf("%s: %s", guidOrMountPoint, err))
}

// Get space info
totalBytes, freeBytes, err := getSpaceInfo(guidBuf)
totalBytes, freeBytes, err := getSpaceInfo(guidOrMountPointBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", guid, err))
warnings = append(warnings, fmt.Sprintf("%s: %s", guidOrMountPoint, err))
}

// Get cluster info
totalClusters, clusterSize, err := getClusterInfo(guidBuf)
totalClusters, clusterSize, err := getClusterInfo(guidOrMountPointBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", guid, err))
}

// Use GUID as volume name if no label was set
if len(volumeName) == 0 {
volumeName = guid
warnings = append(warnings, fmt.Sprintf("%s: %s", guidOrMountPoint, err))
}

m = Mount{
Expand All @@ -111,6 +109,28 @@ func getMountFromGUID(guidBuf []uint16) (m Mount, skip bool, warnings []string)
return
}

func getMountFromGUID(guidBuf []uint16) (m Mount, skip bool, warnings []string) {
m, skip, warnings = getMount(guidBuf, true)

// Use GUID as volume name if no label was set
if len(m.Device) == 0 {
m.Device = windows.UTF16ToString(guidBuf)
}

return
}

func getMountFromMountPoint(mountPointBuf []uint16) (m Mount, warnings []string) {
m, _, warnings = getMount(mountPointBuf, false)

// Use mount point as volume name if no label was set
if len(m.Device) == 0 {
m.Device = windows.UTF16ToString(mountPointBuf)
}

return m, warnings
}

func appendLocalMounts(mounts []Mount, warnings []string) ([]Mount, []string, error) {
guidBuf := make([]uint16, guidBufLen)

Expand Down Expand Up @@ -145,50 +165,19 @@ VolumeLoop:

// Network devices
func getMountFromNetResource(netResource NetResource) (m Mount, warnings []string) {

mountPoint := windows.UTF16PtrToString(netResource.LocalName)
if !strings.HasSuffix(mountPoint, string(filepath.Separator)) {
mountPoint += string(filepath.Separator)
}
mountPointBuf := windows.StringToUTF16(mountPoint)

// Get volume name & filesystem type
volumeName, fsType, err := getVolumeInfo(mountPointBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", mountPoint, err))
}

// Get space info
totalBytes, freeBytes, err := getSpaceInfo(mountPointBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", mountPoint, err))
}

// Get cluster info
totalClusters, clusterSize, err := getClusterInfo(mountPointBuf)
if err != nil {
warnings = append(warnings, fmt.Sprintf("%s: %s", mountPoint, err))
}
m, _, warnings = getMount(mountPointBuf, false)

// Use remote name as volume name if no label was set
if len(volumeName) == 0 {
volumeName = windows.UTF16PtrToString(netResource.RemoteName)
if len(m.Device) == 0 {
m.Device = windows.UTF16PtrToString(netResource.RemoteName)
}

m = Mount{
Device: volumeName,
Mountpoint: mountPoint,
Fstype: fsType,
Type: fsType,
Opts: "",
Total: totalBytes,
Free: freeBytes,
Used: totalBytes - freeBytes,
Blocks: uint64(totalClusters),
BlockSize: uint64(clusterSize),
Metadata: &netResource,
}
m.DeviceType = deviceType(m)
return
}

Expand Down Expand Up @@ -228,6 +217,42 @@ EnumLoop:
return mounts, warnings, nil
}

func mountPointAlreadyPresent(mounts []Mount, mountPoint string) bool {
for _, m := range mounts {
if m.Mountpoint == mountPoint {
return true
}
}

return false
}

func appendLogicalDrives(mounts []Mount, warnings []string) ([]Mount, []string) {
driveBitmap, err := windows.GetLogicalDrives()
if err != nil {
warnings = append(warnings, fmt.Sprintf("GetLogicalDrives(): %s", err))
return mounts, warnings
}

for drive := 'A'; drive <= 'Z'; drive, driveBitmap = drive+1, driveBitmap>>1 {
if driveBitmap&0x1 == 0 {
continue
}

mountPoint := fmt.Sprintf("%c:\\", drive)
if mountPointAlreadyPresent(mounts, mountPoint) {
continue
}

mountPointBuf := windows.StringToUTF16(mountPoint)
m, w := getMountFromMountPoint(mountPointBuf)
mounts = append(mounts, m)
warnings = append(warnings, w...)
}

return mounts, warnings
}

func mounts() (ret []Mount, warnings []string, err error) {
ret = make([]Mount, 0)

Expand All @@ -241,6 +266,10 @@ func mounts() (ret []Mount, warnings []string, err error) {
return
}

// Logical devices (from GetLogicalDrives bitflag)
// Check any possible logical drives, in case of some special virtual devices, such as RAM disk
ret, warnings = appendLogicalDrives(ret, warnings)

return ret, warnings, nil
}

Expand Down

0 comments on commit 7b59b1f

Please sign in to comment.