VoteAllEqual, FindLSB/MSB

This commit is contained in:
Isaac Marovitz 2024-06-22 14:38:09 +01:00 committed by Isaac Marovitz
parent b094d34575
commit a71b5f1a3a
13 changed files with 101 additions and 23 deletions

View file

@ -1,3 +1,4 @@
using Ryujinx.Common;
using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation; using Ryujinx.Graphics.Shader.Translation;
@ -57,6 +58,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
context.AppendLine(); context.AppendLine();
DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values); DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values);
DeclareBufferStructures(context, context.Properties.StorageBuffers.Values); DeclareBufferStructures(context, context.Properties.StorageBuffers.Values);
if ((info.HelperFunctionsMask & HelperFunctionsMask.FindLSB) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindLSB.metal");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.FindMSBS32) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBS32.metal");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.FindMSBU32) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBU32.metal");
}
} }
static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind) static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind)
@ -310,5 +326,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
} }
} }
} }
private static void AppendHelperFunction(CodeGenContext context, string filename)
{
string code = EmbeddedResources.ReadAllText(filename);
code = code.Replace("\t", CodeGenContext.Tab);
context.AppendLine(code);
context.AppendLine();
}
} }
} }

View file

@ -0,0 +1,5 @@
template<typename T>
inline T findLSB(T x)
{
return select(ctz(x), T(-1), x == T(0));
}

View file

@ -0,0 +1,5 @@
template<typename T>
inline T findMSBS32(T x)
{
return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));
}

View file

@ -0,0 +1,6 @@
template<typename T>
inline T findMSBU32(T x)
{
T v = select(x, T(-1) - x, x < T(0));
return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));
}

View file

@ -2,6 +2,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
static class HelperFunctionNames static class HelperFunctionNames
{ {
public static string SwizzleAdd = "helperSwizzleAdd"; public static string FindLSB = "findLSB";
public static string FindMSBS32 = "findMSBS32";
public static string FindMSBU32 = "findMSBU32";
} }
} }

View file

@ -1,4 +0,0 @@
inline bool voteAllEqual(bool value)
{
return simd_all(value) || !simd_any(value);
}

View file

@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Translation;
using System; using System;
using System.Text; using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBallot; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBallot;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBarrier;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
@ -123,19 +124,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
case Instruction.Ballot: case Instruction.Ballot:
return Ballot(context, operation); return Ballot(context, operation);
case Instruction.Barrier: case Instruction.Barrier:
return "threadgroup_barrier(mem_flags::mem_threadgroup)"; return Barrier(context, operation);
case Instruction.Call: case Instruction.Call:
return Call(context, operation); return Call(context, operation);
case Instruction.FSIBegin: case Instruction.FSIBegin:
return "|| FSI BEGIN ||"; return "|| FSI BEGIN ||";
case Instruction.FSIEnd: case Instruction.FSIEnd:
return "|| FSI END ||"; return "|| FSI END ||";
case Instruction.FindLSB:
return "|| FIND LSB ||";
case Instruction.FindMSBS32:
return "|| FIND MSB S32 ||";
case Instruction.FindMSBU32:
return "|| FIND MSB U32 ||";
case Instruction.GroupMemoryBarrier: case Instruction.GroupMemoryBarrier:
return "|| FIND GROUP MEMORY BARRIER ||"; return "|| FIND GROUP MEMORY BARRIER ||";
case Instruction.ImageLoad: case Instruction.ImageLoad:
@ -152,6 +147,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
return "|| MEMORY BARRIER ||"; return "|| MEMORY BARRIER ||";
case Instruction.Store: case Instruction.Store:
return Store(context, operation); return Store(context, operation);
case Instruction.SwizzleAdd:
return "|| SWIZZLE ADD ||";
case Instruction.TextureSample: case Instruction.TextureSample:
return TextureSample(context, operation); return TextureSample(context, operation);
case Instruction.TextureQuerySamples: case Instruction.TextureQuerySamples:
@ -165,7 +162,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
case Instruction.VectorExtract: case Instruction.VectorExtract:
return VectorExtract(context, operation); return VectorExtract(context, operation);
case Instruction.VoteAllEqual: case Instruction.VoteAllEqual:
return "|| VOTE ALL EQUAL ||"; return VoteAllEqual(context, operation);
} }
} }

View file

@ -17,5 +17,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
return $"uint4(as_type<uint2>((simd_vote::vote_t)simd_ballot({arg})), 0, 0).{component}"; return $"uint4(as_type<uint2>((simd_vote::vote_t)simd_ballot({arg})), 0, 0).{component}";
} }
public static string VoteAllEqual(CodeGenContext context, AstOperation operation)
{
AggregateType dstType = GetSrcVarType(operation.Inst, 0);
string arg = GetSourceExpr(context, operation.GetSource(0), dstType);
return $"simd_all({arg}) || !simd_any({arg})";
}
} }
} }

View file

@ -0,0 +1,16 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenBarrier
{
public static string Barrier(CodeGenContext context, AstOperation operation)
{
return "threadgroup_barrier(mem_flags::mem_threadgroup)";
}
}
}

View file

@ -71,10 +71,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.ExponentB2, InstType.CallUnary, "exp2"); Add(Instruction.ExponentB2, InstType.CallUnary, "exp2");
Add(Instruction.FSIBegin, InstType.Special); Add(Instruction.FSIBegin, InstType.Special);
Add(Instruction.FSIEnd, InstType.Special); Add(Instruction.FSIEnd, InstType.Special);
// TODO: LSB and MSB Implementations https://github.com/KhronosGroup/SPIRV-Cross/blob/bccaa94db814af33d8ef05c153e7c34d8bd4d685/reference/shaders-msl-no-opt/asm/comp/bitscan.asm.comp#L8 Add(Instruction.FindLSB, InstType.CallUnary, HelperFunctionNames.FindLSB);
Add(Instruction.FindLSB, InstType.Special); Add(Instruction.FindMSBS32, InstType.CallUnary, HelperFunctionNames.FindMSBS32);
Add(Instruction.FindMSBS32, InstType.Special); Add(Instruction.FindMSBU32, InstType.CallUnary, HelperFunctionNames.FindMSBU32);
Add(Instruction.FindMSBU32, InstType.Special);
Add(Instruction.Floor, InstType.CallUnary, "floor"); Add(Instruction.Floor, InstType.CallUnary, "floor");
Add(Instruction.FusedMultiplyAdd, InstType.CallTernary, "fma"); Add(Instruction.FusedMultiplyAdd, InstType.CallTernary, "fma");
Add(Instruction.GroupMemoryBarrier, InstType.Special); Add(Instruction.GroupMemoryBarrier, InstType.Special);
@ -117,7 +116,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special); Add(Instruction.Store, InstType.Special);
Add(Instruction.Subtract, InstType.OpBinary, "-", 2); Add(Instruction.Subtract, InstType.OpBinary, "-", 2);
Add(Instruction.SwizzleAdd, InstType.CallTernary, HelperFunctionNames.SwizzleAdd); Add(Instruction.SwizzleAdd, InstType.Special);
Add(Instruction.TextureSample, InstType.Special); Add(Instruction.TextureSample, InstType.Special);
Add(Instruction.TextureQuerySamples, InstType.Special); Add(Instruction.TextureQuerySamples, InstType.Special);
Add(Instruction.TextureQuerySize, InstType.Special); Add(Instruction.TextureQuerySize, InstType.Special);

View file

@ -16,6 +16,8 @@
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\VoteAllEqual.metal" /> <EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindLSB.metal" />
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindMSBS32.metal" />
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindMSBU32.metal" />
</ItemGroup> </ItemGroup>
</Project> </Project>

View file

@ -7,6 +7,11 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{ {
MultiplyHighS32 = 1 << 2, MultiplyHighS32 = 1 << 2,
MultiplyHighU32 = 1 << 3, MultiplyHighU32 = 1 << 3,
FindLSB = 1 << 5,
FindMSBS32 = 1 << 6,
FindMSBU32 = 1 << 7,
SwizzleAdd = 1 << 10, SwizzleAdd = 1 << 10,
FSI = 1 << 11, FSI = 1 << 11,
} }

View file

@ -321,8 +321,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
} }
// Those instructions needs to be emulated by using helper functions, // Those instructions needs to be emulated by using helper functions,
// because they are NVIDIA specific. Those flags helps the backend to // because they are NVIDIA specific or because the target language has
// decide which helper functions are needed on the final generated code. // no direct equivalent. Those flags helps the backend to decide which
// helper functions are needed on the final generated code.
switch (operation.Inst) switch (operation.Inst)
{ {
case Instruction.MultiplyHighS32: case Instruction.MultiplyHighS32:
@ -331,6 +332,15 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
case Instruction.MultiplyHighU32: case Instruction.MultiplyHighU32:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32; context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32;
break; break;
case Instruction.FindLSB:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindLSB;
break;
case Instruction.FindMSBS32:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindMSBS32;
break;
case Instruction.FindMSBU32:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindMSBU32;
break;
case Instruction.SwizzleAdd: case Instruction.SwizzleAdd:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd; context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd;
break; break;