using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace Ryujinx.Graphics.Device
{
    /// <summary>
    /// Exposes an unmanaged state structure as a bank of 32-bit registers that can be
    /// read and written by byte offset, with per-register access control and optional
    /// read/write callbacks keyed by field name.
    /// </summary>
    /// <typeparam name="TState">Unmanaged structure that backs the register bank</typeparam>
    public class DeviceState<TState> : IDeviceState where TState : unmanaged
    {
        // Size of a single register, in bytes.
        private const int RegisterSize = sizeof(int);

        /// <summary>
        /// Backing storage for the registers. Kept as a field so that references into it can be taken.
        /// </summary>
        public TState State;

        // One bit per register: whether Read/Write are allowed to touch it.
        private readonly BitArray _readableRegisters;
        private readonly BitArray _writableRegisters;

        // Callbacks keyed by the byte offset of the field they were registered for.
        private readonly Dictionary<int, Func<int>> _readCallbacks;
        private readonly Dictionary<int, Action<int>> _writeCallbacks;

        /// <summary>
        /// Creates a new device state, building the access-control tables and the
        /// callback tables from the public fields of <typeparamref name="TState"/>.
        /// </summary>
        /// <param name="callbacks">
        /// Optional callbacks, keyed by field name. A field without an entry is accessed directly.
        /// </param>
        public DeviceState(IReadOnlyDictionary<string, RwCallback> callbacks = null)
        {
            // Number of registers needed to cover the whole state, rounding up.
            int size = (Unsafe.SizeOf<TState>() + RegisterSize - 1) / RegisterSize;

            _readableRegisters = new BitArray(size);
            _writableRegisters = new BitArray(size);

            _readCallbacks = new Dictionary<int, Func<int>>();
            _writeCallbacks = new Dictionary<int, Action<int>>();

            // NOTE(review): offsets are accumulated assuming GetFields returns the fields in
            // layout order; the reflection documentation does not guarantee any order, and the
            // assert below only validates the total size — confirm this holds for all TState types.
            var fields = typeof(TState).GetFields();
            int offset = 0;

            for (int fieldIndex = 0; fieldIndex < fields.Length; fieldIndex++)
            {
                var field = fields[fieldIndex];
                var regAttr = field.GetCustomAttributes<RegisterAttribute>(false).FirstOrDefault();

                int sizeOfField = SizeCalculator.SizeOf(field.FieldType);

                // A field without a register attribute is both readable and writable.
                bool readable = regAttr?.AccessControl.HasFlag(AccessControl.ReadOnly) ?? true;
                bool writable = regAttr?.AccessControl.HasFlag(AccessControl.WriteOnly) ?? true;

                // Field size rounded up to a whole number of registers.
                int alignedFieldSize = (sizeOfField + RegisterSize - 1) & ~(RegisterSize - 1);

                for (int i = 0; i < alignedFieldSize; i += RegisterSize)
                {
                    int registerIndex = (offset + i) / RegisterSize;

                    _readableRegisters[registerIndex] = readable;
                    _writableRegisters[registerIndex] = writable;
                }

                if (callbacks != null && callbacks.TryGetValue(field.Name, out var cb))
                {
                    if (cb.Read != null)
                    {
                        _readCallbacks.Add(offset, cb.Read);
                    }

                    if (cb.Write != null)
                    {
                        _writeCallbacks.Add(offset, cb.Write);
                    }
                }

                offset += sizeOfField;
            }

            Debug.Assert(offset == Unsafe.SizeOf<TState>());
        }

        /// <summary>
        /// Reads the register that contains the given byte offset.
        /// </summary>
        /// <param name="offset">Byte offset into the state; it is aligned down to a register boundary</param>
        /// <returns>
        /// The result of the read callback registered at the aligned offset if there is one,
        /// otherwise the stored value; 0 if the offset is out of range or the register is not readable
        /// </returns>
        public virtual int Read(int offset)
        {
            if (Check(offset) && _readableRegisters[offset / RegisterSize])
            {
                int alignedOffset = Align(offset);

                if (_readCallbacks.TryGetValue(alignedOffset, out Func<int> read))
                {
                    return read();
                }
                else
                {
                    return GetRef<int>(alignedOffset);
                }
            }

            return 0;
        }

        /// <summary>
        /// Writes the register that contains the given byte offset.
        /// The write is ignored if the offset is out of range or the register is not writable.
        /// </summary>
        /// <param name="offset">Byte offset into the state; it is aligned down to a register boundary</param>
        /// <param name="data">Value to store</param>
        public virtual void Write(int offset, int data)
        {
            if (Check(offset) && _writableRegisters[offset / RegisterSize])
            {
                int alignedOffset = Align(offset);

                // The value is stored before the callback runs, so the callback observes the new state.
                GetRef<int>(alignedOffset) = data;

                if (_writeCallbacks.TryGetValue(alignedOffset, out Action<int> write))
                {
                    write(data);
                }
            }
        }

        /// <summary>
        /// Checks whether the register-aligned offset starts inside the state.
        /// </summary>
        /// <param name="offset">Byte offset to check</param>
        /// <returns>True if the aligned offset is non-negative and below the state size</returns>
        private bool Check(int offset)
        {
            // The unsigned cast makes negative offsets compare as very large values.
            return (uint)Align(offset) < Unsafe.SizeOf<TState>();
        }

        /// <summary>
        /// Gets a reference to a value of type <typeparamref name="T"/> located at the given byte offset of the state.
        /// </summary>
        /// <typeparam name="T">Type of the value to reference</typeparam>
        /// <param name="offset">Byte offset into the state</param>
        /// <returns>Reference to the value at the given offset</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="offset"/> is negative, or the value would extend past the end of the state
        /// </exception>
        public ref T GetRef<T>(int offset) where T : unmanaged
        {
            // Negative offsets are rejected explicitly: adding the size of T to a small negative
            // offset yields a small non-negative number, which an unsigned comparison alone would
            // accept. The end offset is computed as long so that it cannot overflow.
            if (offset < 0 || (long)offset + Unsafe.SizeOf<T>() > Unsafe.SizeOf<TState>())
            {
                throw new ArgumentOutOfRangeException(nameof(offset));
            }

            return ref Unsafe.As<TState, T>(ref Unsafe.AddByteOffset(ref State, (IntPtr)offset));
        }

        /// <summary>
        /// Aligns a byte offset down to the start of the register that contains it.
        /// </summary>
        /// <param name="offset">Byte offset to align</param>
        /// <returns>The offset with its low bits cleared to a multiple of the register size</returns>
        private static int Align(int offset)
        {
            return offset & ~(RegisterSize - 1);
        }
    }
}