From 87ea3157871084393ced33fda426dbabc45157e3 Mon Sep 17 00:00:00 2001 From: yell0wsuit <5692900+yell0wsuit@users.noreply.github.com> Date: Tue, 16 Apr 2024 23:02:20 +0700 Subject: [PATCH] Refactor part 2 Replace `WebClient` with `HttpClient`, plus updating DoUpdateWithSingleThreadWorker and DoUpdateWithMultipleThreads --- src/Ryujinx/Modules/Updater/Updater.cs | 255 ++++++++++++------------- 1 file changed, 126 insertions(+), 129 deletions(-) diff --git a/src/Ryujinx/Modules/Updater/Updater.cs b/src/Ryujinx/Modules/Updater/Updater.cs index a6481ba12..88f02ffa5 100644 --- a/src/Ryujinx/Modules/Updater/Updater.cs +++ b/src/Ryujinx/Modules/Updater/Updater.cs @@ -19,6 +19,7 @@ using System.IO; using System.Linq; using System.Net; using System.Net.Http; +using System.Net.Http.Headers; using System.Net.NetworkInformation; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -186,11 +187,13 @@ namespace Ryujinx.Modules // Fetch build size information to learn chunk sizes. try { - httpClient.DefaultRequestHeaders.Add("Range", "bytes=0-0"); + httpClient.DefaultRequestHeaders.Range = new RangeHeaderValue(0, 0); HttpResponseMessage message = await httpClient.GetAsync(new Uri(_buildUrl), HttpCompletionOption.ResponseHeadersRead); _buildSize = message.Content.Headers.ContentRange.Length.Value; + + httpClient.DefaultRequestHeaders.Remove("Range"); } catch (Exception ex) { @@ -251,11 +254,11 @@ namespace Ryujinx.Modules XamlRoot = parent, }; - taskDialog.Opened += (s, e) => + taskDialog.Opened += async (s, e) => { if (_buildSize >= 0) { - DoUpdateWithMultipleThreads(taskDialog, downloadUrl, updateFile); + await DoUpdateWithMultipleThreads(taskDialog, downloadUrl, updateFile); } else { @@ -335,79 +338,33 @@ namespace Ryujinx.Modules } } - private static void DoUpdateWithMultipleThreads(TaskDialog taskDialog, string downloadUrl, string updateFile) + private static async Task DoUpdateWithMultipleThreads(TaskDialog taskDialog, string downloadUrl, string updateFile) { - // Multi-Threaded Updater long chunkSize = _buildSize / ConnectionCount; long remainderChunk = _buildSize % ConnectionCount; int completedRequests = 0; - int totalProgressPercentage = 0; int[] progressPercentage = new int[ConnectionCount]; + List chunkDataList = new List(new byte[ConnectionCount][]); - List list = new(ConnectionCount); - List webClients = new(ConnectionCount); + List downloadTasks = new List(); for (int i = 0; i < ConnectionCount; i++) { - list.Add(Array.Empty()); - } + long rangeStart = i * chunkSize; + long rangeEnd = (i == ConnectionCount - 1) ? (rangeStart + chunkSize + remainderChunk - 1) : (rangeStart + chunkSize - 1); + int index = i; - for (int i = 0; i < ConnectionCount; i++) - { -#pragma warning disable SYSLIB0014 - // TODO: WebClient is obsolete and need to be replaced with a more complex logic using HttpClient. - using WebClient client = new(); -#pragma warning restore SYSLIB0014 - - webClients.Add(client); - - if (i == ConnectionCount - 1) + downloadTasks.Add(Task.Run(async () => { - client.Headers.Add("Range", $"bytes={chunkSize * i}-{(chunkSize * (i + 1) - 1) + remainderChunk}"); - } - else - { - client.Headers.Add("Range", $"bytes={chunkSize * i}-{chunkSize * (i + 1) - 1}"); - } + byte[] chunkData = await DownloadFileChunk(downloadUrl, rangeStart, rangeEnd, index, taskDialog, progressPercentage); + chunkDataList[index] = chunkData; - client.DownloadProgressChanged += (_, args) => - { - int index = (int)args.UserState; - - Interlocked.Add(ref totalProgressPercentage, -1 * progressPercentage[index]); - Interlocked.Exchange(ref progressPercentage[index], args.ProgressPercentage); - Interlocked.Add(ref totalProgressPercentage, args.ProgressPercentage); - - taskDialog.SetProgressBarState(totalProgressPercentage / ConnectionCount, TaskDialogProgressState.Normal); - }; - - client.DownloadDataCompleted += (_, args) => - { - int index = (int)args.UserState; - - if (args.Cancelled) - { - webClients[index].Dispose(); - - taskDialog.Hide(); - - return; - } - - list[index] = args.Result; Interlocked.Increment(ref completedRequests); - - if (Equals(completedRequests, ConnectionCount)) + if (Interlocked.Equals(completedRequests, ConnectionCount)) { - byte[] mergedFileBytes = new byte[_buildSize]; - for (int connectionIndex = 0, destinationOffset = 0; connectionIndex < ConnectionCount; connectionIndex++) - { - Array.Copy(list[connectionIndex], 0, mergedFileBytes, destinationOffset, list[connectionIndex].Length); - destinationOffset += list[connectionIndex].Length; - } - - File.WriteAllBytes(updateFile, mergedFileBytes); + byte[] allData = CombineChunks(chunkDataList, _buildSize); + File.WriteAllBytes(updateFile, allData); // On macOS, ensure that we remove the quarantine bit to prevent Gatekeeper from blocking execution. if (OperatingSystem.IsMacOS()) @@ -417,73 +374,108 @@ namespace Ryujinx.Modules xattrProcess.WaitForExit(); } - try + // Ensure that the install update is run on the UI thread. + await Dispatcher.UIThread.InvokeAsync(async () => { - InstallUpdate(taskDialog, updateFile); - } - catch (Exception e) - { - Logger.Warning?.Print(LogClass.Application, e.Message); - Logger.Warning?.Print(LogClass.Application, "Multi-Threaded update failed, falling back to single-threaded updater."); - - DoUpdateWithSingleThread(taskDialog, downloadUrl, updateFile); - } + try + { + await InstallUpdate(taskDialog, updateFile); + } + catch (Exception e) + { + Logger.Warning?.Print(LogClass.Application, e.Message); + Logger.Warning?.Print(LogClass.Application, "Multi-Threaded update failed, falling back to single-threaded updater."); + DoUpdateWithSingleThread(taskDialog, downloadUrl, updateFile); + } + }); } - }; - - try - { - client.DownloadDataAsync(new Uri(downloadUrl), i); - } - catch (WebException ex) - { - Logger.Warning?.Print(LogClass.Application, ex.Message); - Logger.Warning?.Print(LogClass.Application, "Multi-Threaded update failed, falling back to single-threaded updater."); - - foreach (WebClient webClient in webClients) - { - webClient.CancelAsync(); - } - - DoUpdateWithSingleThread(taskDialog, downloadUrl, updateFile); - - return; - } + })); } + + await Task.WhenAll(downloadTasks); } - private static void DoUpdateWithSingleThreadWorker(TaskDialog taskDialog, string downloadUrl, string updateFile) + private static byte[] CombineChunks(List chunks, long totalSize) { - using HttpClient client = new(); - // We do not want to timeout while downloading - client.Timeout = TimeSpan.FromDays(1); - - using HttpResponseMessage response = client.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead).Result; - using Stream remoteFileStream = response.Content.ReadAsStreamAsync().Result; - using Stream updateFileStream = File.Open(updateFile, FileMode.Create); - - long totalBytes = response.Content.Headers.ContentLength.Value; - long byteWritten = 0; - - byte[] buffer = new byte[32 * 1024]; - - while (true) + byte[] data = new byte[totalSize]; + long position = 0; + foreach (byte[] chunk in chunks) { - int readSize = remoteFileStream.Read(buffer); + Buffer.BlockCopy(chunk, 0, data, (int)position, chunk.Length); + position += chunk.Length; + } + return data; + } - if (readSize == 0) + private static async Task DownloadFileChunk(string url, long start, long end, int index, TaskDialog taskDialog, int[] progressPercentage) + { + byte[] buffer = new byte[8192]; + using var request = new HttpRequestMessage(HttpMethod.Get, url); + request.Headers.Range = new RangeHeaderValue(start, end); + HttpResponseMessage response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + using var stream = await response.Content.ReadAsStreamAsync(); + using var memoryStream = new MemoryStream(); + int bytesRead; + long totalRead = 0; + + while ((bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + memoryStream.Write(buffer, 0, bytesRead); + totalRead += bytesRead; + int progress = (int)((totalRead * 100) / (end - start + 1)); + progressPercentage[index] = progress; + + Dispatcher.UIThread.Post(() => { - break; - } - - byteWritten += readSize; - - taskDialog.SetProgressBarState(GetPercentage(byteWritten, totalBytes), TaskDialogProgressState.Normal); - - updateFileStream.Write(buffer, 0, readSize); + taskDialog.SetProgressBarState(progressPercentage.Sum() / ConnectionCount, TaskDialogProgressState.Normal); + }); } - InstallUpdate(taskDialog, updateFile); + return memoryStream.ToArray(); + } + + private static async Task DoUpdateWithSingleThreadWorker(TaskDialog taskDialog, string downloadUrl, string updateFile) + { + // We do not want to timeout while downloading + httpClient.Timeout = TimeSpan.FromDays(1); + + // Use the existing httpClient instance, correctly configured + HttpResponseMessage response = await httpClient.GetAsync(downloadUrl, HttpCompletionOption.ResponseHeadersRead); + if (!response.IsSuccessStatusCode) + { + throw new HttpRequestException($"Failed to download file: {response.ReasonPhrase}"); + } + + long totalBytes = response.Content.Headers.ContentLength ?? 0; + long byteWritten = 0; + + // Ensure the entire content body is read asynchronously + using Stream remoteFileStream = await response.Content.ReadAsStreamAsync(); + using Stream updateFileStream = File.Open(updateFile, FileMode.Create); + + byte[] buffer = new byte[32 * 1024]; + int readSize; + + while ((readSize = await remoteFileStream.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + updateFileStream.Write(buffer, 0, readSize); + byteWritten += readSize; + + int progress = GetPercentage(byteWritten, totalBytes); + Dispatcher.UIThread.Post(() => + { + taskDialog.SetProgressBarState(progress, TaskDialogProgressState.Normal); + }); + } + + await InstallUpdate(taskDialog, updateFile); + } + + private static int GetPercentage(long value, long total) + { + if (total == 0) + return 0; + return (int)((value * 100) / total); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -494,12 +486,10 @@ namespace Ryujinx.Modules private static void DoUpdateWithSingleThread(TaskDialog taskDialog, string downloadUrl, string updateFile) { - Thread worker = new(() => DoUpdateWithSingleThreadWorker(taskDialog, downloadUrl, updateFile)) + Task.Run(async () => { - Name = "Updater.SingleThreadWorker", - }; - - worker.Start(); + await DoUpdateWithSingleThreadWorker(taskDialog, downloadUrl, updateFile); + }); } [SupportedOSPlatform("linux")] @@ -573,11 +563,14 @@ namespace Ryujinx.Modules } } - private static void InstallUpdate(TaskDialog taskDialog, string updateFile) + private static async Task InstallUpdate(TaskDialog taskDialog, string updateFile) { // Extract Update - taskDialog.SubHeader = LocaleManager.Instance[LocaleKeys.UpdaterExtracting]; - taskDialog.SetProgressBarState(0, TaskDialogProgressState.Normal); + await Dispatcher.UIThread.InvokeAsync(() => + { + taskDialog.SubHeader = LocaleManager.Instance[LocaleKeys.UpdaterExtracting]; + taskDialog.SetProgressBarState(0, TaskDialogProgressState.Normal); + }); if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) { @@ -597,8 +590,12 @@ namespace Ryujinx.Modules List allFiles = EnumerateFilesToDelete().ToList(); - taskDialog.SubHeader = LocaleManager.Instance[LocaleKeys.UpdaterRenaming]; - taskDialog.SetProgressBarState(0, TaskDialogProgressState.Normal); + await Dispatcher.UIThread.InvokeAsync(() => + { + taskDialog.SubHeader = LocaleManager.Instance[LocaleKeys.UpdaterRenaming]; + taskDialog.SetProgressBarState(0, TaskDialogProgressState.Normal); + taskDialog.Hide(); + }); // NOTE: On macOS, replacement is delayed to the restart phase. if (!OperatingSystem.IsMacOS()) @@ -612,7 +609,7 @@ namespace Ryujinx.Modules { File.Move(file, file + ".ryuold"); - Dispatcher.UIThread.InvokeAsync(() => + await Dispatcher.UIThread.InvokeAsync(() => { taskDialog.SetProgressBarState(GetPercentage(count, allFiles.Count), TaskDialogProgressState.Normal); }); @@ -623,7 +620,7 @@ namespace Ryujinx.Modules } } - Dispatcher.UIThread.InvokeAsync(() => + await Dispatcher.UIThread.InvokeAsync(() => { taskDialog.SubHeader = LocaleManager.Instance[LocaleKeys.UpdaterAddingFiles]; taskDialog.SetProgressBarState(0, TaskDialogProgressState.Normal);