Fix shared memory leak on Windows (#3319)

* Fix shared memory leak on Windows

* Fix memory leak caused by RO session disposal not decrementing the memory manager ref count

* Fix UnmapViewInternal deadlock

* Was not supposed to add those back
This commit is contained in:
gdkchan 2022-05-05 14:58:59 -03:00 committed by GitHub
parent 39bdf6d41e
commit 54deded929
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 226 additions and 59 deletions

View file

@ -30,6 +30,7 @@ namespace Ryujinx.HLE.HOS.Services.Ro
private List<NroInfo> _nroInfos; private List<NroInfo> _nroInfos;
private KProcess _owner; private KProcess _owner;
private IVirtualMemoryManager _ownerMm;
private static Random _random = new Random(); private static Random _random = new Random();
@ -38,6 +39,7 @@ namespace Ryujinx.HLE.HOS.Services.Ro
_nrrInfos = new List<NrrInfo>(MaxNrr); _nrrInfos = new List<NrrInfo>(MaxNrr);
_nroInfos = new List<NroInfo>(MaxNro); _nroInfos = new List<NroInfo>(MaxNro);
_owner = null; _owner = null;
_ownerMm = null;
} }
private ResultCode ParseNrr(out NrrInfo nrrInfo, ServiceCtx context, ulong nrrAddress, ulong nrrSize) private ResultCode ParseNrr(out NrrInfo nrrInfo, ServiceCtx context, ulong nrrAddress, ulong nrrSize)
@ -564,10 +566,12 @@ namespace Ryujinx.HLE.HOS.Services.Ro
return ResultCode.InvalidSession; return ResultCode.InvalidSession;
} }
_owner = context.Process.HandleTable.GetKProcess(context.Request.HandleDesc.ToCopy[0]); int processHandle = context.Request.HandleDesc.ToCopy[0];
context.Device.System.KernelContext.Syscall.CloseHandle(context.Request.HandleDesc.ToCopy[0]); _owner = context.Process.HandleTable.GetKProcess(processHandle);
_ownerMm = _owner?.CpuMemory;
context.Device.System.KernelContext.Syscall.CloseHandle(processHandle);
if (_owner?.CpuMemory is IRefCounted rc) if (_ownerMm is IRefCounted rc)
{ {
rc.IncrementReferenceCount(); rc.IncrementReferenceCount();
} }
@ -586,7 +590,7 @@ namespace Ryujinx.HLE.HOS.Services.Ro
_nroInfos.Clear(); _nroInfos.Clear();
if (_owner?.CpuMemory is IRefCounted rc) if (_ownerMm is IRefCounted rc)
{ {
rc.DecrementReferenceCount(); rc.DecrementReferenceCount();
} }

View file

@ -48,7 +48,7 @@ namespace Ryujinx.Memory
{ {
_viewCompatible = flags.HasFlag(MemoryAllocationFlags.ViewCompatible); _viewCompatible = flags.HasFlag(MemoryAllocationFlags.ViewCompatible);
_forceWindows4KBView = flags.HasFlag(MemoryAllocationFlags.ForceWindows4KBViewMapping); _forceWindows4KBView = flags.HasFlag(MemoryAllocationFlags.ForceWindows4KBViewMapping);
_pointer = MemoryManagement.Reserve(size, _viewCompatible); _pointer = MemoryManagement.Reserve(size, _viewCompatible, _forceWindows4KBView);
} }
else else
{ {
@ -404,7 +404,7 @@ namespace Ryujinx.Memory
} }
else else
{ {
MemoryManagement.Free(ptr); MemoryManagement.Free(ptr, Size, _forceWindows4KBView);
} }
foreach (MemoryBlock viewStorage in _viewStorages.Keys) foreach (MemoryBlock viewStorage in _viewStorages.Keys)

View file

@ -8,9 +8,7 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size); return MemoryManagementWindows.Allocate((IntPtr)size);
return MemoryManagementWindows.Allocate(sizeNint);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -22,13 +20,11 @@ namespace Ryujinx.Memory
} }
} }
public static IntPtr Reserve(ulong size, bool viewCompatible) public static IntPtr Reserve(ulong size, bool viewCompatible, bool force4KBMap)
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size); return MemoryManagementWindows.Reserve((IntPtr)size, viewCompatible, force4KBMap);
return MemoryManagementWindows.Reserve(sizeNint, viewCompatible);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -44,9 +40,7 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size); return MemoryManagementWindows.Commit(address, (IntPtr)size);
return MemoryManagementWindows.Commit(address, sizeNint);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -62,9 +56,7 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size); return MemoryManagementWindows.Decommit(address, (IntPtr)size);
return MemoryManagementWindows.Decommit(address, sizeNint);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -80,15 +72,13 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size);
if (force4KBMap) if (force4KBMap)
{ {
MemoryManagementWindows.MapView4KB(sharedMemory, srcOffset, address, sizeNint); MemoryManagementWindows.MapView4KB(sharedMemory, srcOffset, address, (IntPtr)size);
} }
else else
{ {
MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, sizeNint); MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, (IntPtr)size);
} }
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
@ -105,15 +95,13 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size);
if (force4KBMap) if (force4KBMap)
{ {
MemoryManagementWindows.UnmapView4KB(address, sizeNint); MemoryManagementWindows.UnmapView4KB(address, (IntPtr)size);
} }
else else
{ {
MemoryManagementWindows.UnmapView(sharedMemory, address, sizeNint); MemoryManagementWindows.UnmapView(sharedMemory, address, (IntPtr)size);
} }
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
@ -132,15 +120,13 @@ namespace Ryujinx.Memory
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size);
if (forView && force4KBMap) if (forView && force4KBMap)
{ {
result = MemoryManagementWindows.Reprotect4KB(address, sizeNint, permission, forView); result = MemoryManagementWindows.Reprotect4KB(address, (IntPtr)size, permission, forView);
} }
else else
{ {
result = MemoryManagementWindows.Reprotect(address, sizeNint, permission, forView); result = MemoryManagementWindows.Reprotect(address, (IntPtr)size, permission, forView);
} }
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
@ -158,11 +144,11 @@ namespace Ryujinx.Memory
} }
} }
public static bool Free(IntPtr address) public static bool Free(IntPtr address, ulong size, bool force4KBMap)
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
return MemoryManagementWindows.Free(address); return MemoryManagementWindows.Free(address, (IntPtr)size, force4KBMap);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -178,9 +164,7 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
IntPtr sizeNint = new IntPtr((long)size); return MemoryManagementWindows.CreateSharedMemory((IntPtr)size, reserve);
return MemoryManagementWindows.CreateSharedMemory(sizeNint, reserve);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {

View file

@ -7,21 +7,27 @@ namespace Ryujinx.Memory
[SupportedOSPlatform("windows")] [SupportedOSPlatform("windows")]
static class MemoryManagementWindows static class MemoryManagementWindows
{ {
private const int PageSize = 0x1000; public const int PageSize = 0x1000;
private static readonly PlaceholderManager _placeholders = new PlaceholderManager(); private static readonly PlaceholderManager _placeholders = new PlaceholderManager();
private static readonly PlaceholderManager4KB _placeholders4KB = new PlaceholderManager4KB();
public static IntPtr Allocate(IntPtr size) public static IntPtr Allocate(IntPtr size)
{ {
return AllocateInternal(size, AllocationType.Reserve | AllocationType.Commit); return AllocateInternal(size, AllocationType.Reserve | AllocationType.Commit);
} }
public static IntPtr Reserve(IntPtr size, bool viewCompatible) public static IntPtr Reserve(IntPtr size, bool viewCompatible, bool force4KBMap)
{ {
if (viewCompatible) if (viewCompatible)
{ {
IntPtr baseAddress = AllocateInternal2(size, AllocationType.Reserve | AllocationType.ReservePlaceholder); IntPtr baseAddress = AllocateInternal2(size, AllocationType.Reserve | AllocationType.ReservePlaceholder);
if (!force4KBMap)
{
_placeholders.ReserveRange((ulong)baseAddress, (ulong)size); _placeholders.ReserveRange((ulong)baseAddress, (ulong)size);
}
return baseAddress; return baseAddress;
} }
@ -69,6 +75,8 @@ namespace Ryujinx.Memory
public static void MapView4KB(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size) public static void MapView4KB(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size)
{ {
_placeholders4KB.UnmapAndMarkRangeAsMapped(location, size);
ulong uaddress = (ulong)location; ulong uaddress = (ulong)location;
ulong usize = (ulong)size; ulong usize = (ulong)size;
IntPtr endLocation = (IntPtr)(uaddress + usize); IntPtr endLocation = (IntPtr)(uaddress + usize);
@ -105,20 +113,7 @@ namespace Ryujinx.Memory
public static void UnmapView4KB(IntPtr location, IntPtr size) public static void UnmapView4KB(IntPtr location, IntPtr size)
{ {
ulong uaddress = (ulong)location; _placeholders4KB.UnmapView(location, size);
ulong usize = (ulong)size;
IntPtr endLocation = (IntPtr)(uaddress + usize);
while (location != endLocation)
{
bool result = WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, location, 2);
if (!result)
{
throw new WindowsApiException("UnmapViewOfFile2");
}
location += PageSize;
}
} }
public static bool Reprotect(IntPtr address, IntPtr size, MemoryPermission permission, bool forView) public static bool Reprotect(IntPtr address, IntPtr size, MemoryPermission permission, bool forView)
@ -151,8 +146,17 @@ namespace Ryujinx.Memory
return true; return true;
} }
public static bool Free(IntPtr address) public static bool Free(IntPtr address, IntPtr size, bool force4KBMap)
{ {
if (force4KBMap)
{
_placeholders4KB.UnmapRange(address, size);
}
else
{
_placeholders.UnmapView(IntPtr.Zero, address, size);
}
return WindowsApi.VirtualFree(address, IntPtr.Zero, AllocationType.Release); return WindowsApi.VirtualFree(address, IntPtr.Zero, AllocationType.Release);
} }

View file

@ -1,5 +1,6 @@
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.Runtime.Versioning;
using System.Threading; using System.Threading;
namespace Ryujinx.Memory.WindowsShared namespace Ryujinx.Memory.WindowsShared
@ -7,6 +8,7 @@ namespace Ryujinx.Memory.WindowsShared
/// <summary> /// <summary>
/// Windows memory placeholder manager. /// Windows memory placeholder manager.
/// </summary> /// </summary>
[SupportedOSPlatform("windows")]
class PlaceholderManager class PlaceholderManager
{ {
private const ulong MinimumPageSize = 0x1000; private const ulong MinimumPageSize = 0x1000;
@ -203,7 +205,7 @@ namespace Ryujinx.Memory.WindowsShared
ulong endAddress = startAddress + unmapSize; ulong endAddress = startAddress + unmapSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, ulong>>(); var overlaps = Array.Empty<IntervalTreeNode<ulong, ulong>>();
int count = 0; int count;
lock (_mappings) lock (_mappings)
{ {
@ -226,8 +228,11 @@ namespace Ryujinx.Memory.WindowsShared
ulong overlapEnd = overlap.End; ulong overlapEnd = overlap.End;
ulong overlapValue = overlap.Value; ulong overlapValue = overlap.Value;
lock (_mappings)
{
_mappings.Remove(overlap); _mappings.Remove(overlap);
_mappings.Add(overlapStart, overlapEnd, ulong.MaxValue); _mappings.Add(overlapStart, overlapEnd, ulong.MaxValue);
}
bool overlapStartsBefore = overlapStart < startAddress; bool overlapStartsBefore = overlapStart < startAddress;
bool overlapEndsAfter = overlapEnd > endAddress; bool overlapEndsAfter = overlapEnd > endAddress;
@ -364,7 +369,7 @@ namespace Ryujinx.Memory.WindowsShared
ulong endAddress = reprotectAddress + reprotectSize; ulong endAddress = reprotectAddress + reprotectSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, ulong>>(); var overlaps = Array.Empty<IntervalTreeNode<ulong, ulong>>();
int count = 0; int count;
lock (_mappings) lock (_mappings)
{ {
@ -534,7 +539,7 @@ namespace Ryujinx.Memory.WindowsShared
{ {
ulong endAddress = address + size; ulong endAddress = address + size;
var overlaps = Array.Empty<IntervalTreeNode<ulong, MemoryPermission>>(); var overlaps = Array.Empty<IntervalTreeNode<ulong, MemoryPermission>>();
int count = 0; int count;
lock (_protections) lock (_protections)
{ {

View file

@ -0,0 +1,170 @@
using System;
using System.Runtime.Versioning;
namespace Ryujinx.Memory.WindowsShared
{
/// <summary>
/// Windows 4KB memory placeholder manager.
/// </summary>
[SupportedOSPlatform("windows")]
class PlaceholderManager4KB
{
private const int PageSize = MemoryManagementWindows.PageSize;
private readonly IntervalTree<ulong, byte> _mappings;
/// <summary>
/// Creates a new instance of the Windows 4KB memory placeholder manager.
/// </summary>
public PlaceholderManager4KB()
{
_mappings = new IntervalTree<ulong, byte>();
}
/// <summary>
/// Unmaps the specified range of memory and marks it as mapped internally.
/// </summary>
/// <remarks>
/// Since this marks the range as mapped, the expectation is that the range will be mapped after calling this method.
/// </remarks>
/// <param name="location">Memory address to unmap and mark as mapped</param>
/// <param name="size">Size of the range in bytes</param>
public void UnmapAndMarkRangeAsMapped(IntPtr location, IntPtr size)
{
ulong startAddress = (ulong)location;
ulong unmapSize = (ulong)size;
ulong endAddress = startAddress + unmapSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
int count = 0;
lock (_mappings)
{
count = _mappings.Get(startAddress, endAddress, ref overlaps);
}
for (int index = 0; index < count; index++)
{
var overlap = overlaps[index];
// Tree operations might modify the node start/end values, so save a copy before we modify the tree.
ulong overlapStart = overlap.Start;
ulong overlapEnd = overlap.End;
ulong overlapValue = overlap.Value;
_mappings.Remove(overlap);
ulong unmapStart = Math.Max(overlapStart, startAddress);
ulong unmapEnd = Math.Min(overlapEnd, endAddress);
if (overlapStart < startAddress)
{
startAddress = overlapStart;
}
if (overlapEnd > endAddress)
{
endAddress = overlapEnd;
}
ulong currentAddress = unmapStart;
while (currentAddress < unmapEnd)
{
WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
currentAddress += PageSize;
}
}
_mappings.Add(startAddress, endAddress, 0);
}
/// <summary>
/// Unmaps views at the specified memory range.
/// </summary>
/// <param name="location">Address of the range</param>
/// <param name="size">Size of the range in bytes</param>
public void UnmapView(IntPtr location, IntPtr size)
{
ulong startAddress = (ulong)location;
ulong unmapSize = (ulong)size;
ulong endAddress = startAddress + unmapSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
int count = 0;
lock (_mappings)
{
count = _mappings.Get(startAddress, endAddress, ref overlaps);
}
for (int index = 0; index < count; index++)
{
var overlap = overlaps[index];
// Tree operations might modify the node start/end values, so save a copy before we modify the tree.
ulong overlapStart = overlap.Start;
ulong overlapEnd = overlap.End;
_mappings.Remove(overlap);
if (overlapStart < startAddress)
{
_mappings.Add(overlapStart, startAddress, 0);
}
if (overlapEnd > endAddress)
{
_mappings.Add(endAddress, overlapEnd, 0);
}
ulong unmapStart = Math.Max(overlapStart, startAddress);
ulong unmapEnd = Math.Min(overlapEnd, endAddress);
ulong currentAddress = unmapStart;
while (currentAddress < unmapEnd)
{
WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
currentAddress += PageSize;
}
}
}
/// <summary>
/// Unmaps mapped memory at a given range.
/// </summary>
/// <param name="location">Address of the range</param>
/// <param name="size">Size of the range in bytes</param>
public void UnmapRange(IntPtr location, IntPtr size)
{
ulong startAddress = (ulong)location;
ulong unmapSize = (ulong)size;
ulong endAddress = startAddress + unmapSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
int count = 0;
lock (_mappings)
{
count = _mappings.Get(startAddress, endAddress, ref overlaps);
}
for (int index = 0; index < count; index++)
{
var overlap = overlaps[index];
// Tree operations might modify the node start/end values, so save a copy before we modify the tree.
ulong unmapStart = Math.Max(overlap.Start, startAddress);
ulong unmapEnd = Math.Min(overlap.End, endAddress);
_mappings.Remove(overlap);
ulong currentAddress = unmapStart;
while (currentAddress < unmapEnd)
{
WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
currentAddress += PageSize;
}
}
}
}
}