From 850b08a16cecf260e7c8e07b81b5e0078622974d Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sun, 23 May 2021 04:27:09 -0300
Subject: [PATCH] spirv: Be aware of NAN unaware drivers

---
 .../spirv/emit_spirv_floating_point.cpp       | 58 +++++++++++++------
 1 file changed, 40 insertions(+), 18 deletions(-)

diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp
index 97d11cc637..b3afbef259 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp
@@ -22,6 +22,28 @@ Id Clamp(EmitContext& ctx, Id type, Id value, Id zero, Id one) {
         return ctx.OpFClamp(type, value, zero, one);
     }
 }
+
+Id FPOrdNotEqual(EmitContext& ctx, Id lhs, Id rhs) {
+    if (ctx.profile.ignore_nan_fp_comparisons) {
+        const Id comp{ctx.OpFOrdEqual(ctx.U1, lhs, rhs)};
+        const Id lhs_not_nan{ctx.OpLogicalNot(ctx.U1, ctx.OpIsNan(ctx.U1, lhs))};
+        const Id rhs_not_nan{ctx.OpLogicalNot(ctx.U1, ctx.OpIsNan(ctx.U1, rhs))};
+        return ctx.OpLogicalAnd(ctx.U1, ctx.OpLogicalAnd(ctx.U1, comp, lhs_not_nan), rhs_not_nan);
+    } else {
+        return ctx.OpFOrdNotEqual(ctx.U1, lhs, rhs);
+    }
+}
+
+Id FPUnordCompare(Id (EmitContext::*comp_func)(Id, Id, Id), EmitContext& ctx, Id lhs, Id rhs) {
+    if (ctx.profile.ignore_nan_fp_comparisons) {
+        const Id lhs_nan{ctx.OpIsNan(ctx.U1, lhs)};
+        const Id rhs_nan{ctx.OpIsNan(ctx.U1, rhs)};
+        const Id comp{(ctx.*comp_func)(ctx.U1, lhs, rhs)};
+        return ctx.OpLogicalOr(ctx.U1, ctx.OpLogicalOr(ctx.U1, comp, lhs_nan), rhs_nan);
+    } else {
+        return (ctx.*comp_func)(ctx.U1, lhs, rhs);
+    }
+}
 } // Anonymous namespace
 
 Id EmitFPAbs16(EmitContext& ctx, Id value) {
@@ -227,27 +249,27 @@ Id EmitFPOrdEqual64(EmitContext& ctx, Id lhs, Id rhs) {
 }
 
 Id EmitFPUnordEqual16(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordEqual32(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordEqual64(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPOrdNotEqual16(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFOrdNotEqual(ctx.U1, lhs, rhs);
+    return FPOrdNotEqual(ctx, lhs, rhs);
 }
 
 Id EmitFPOrdNotEqual32(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFOrdNotEqual(ctx.U1, lhs, rhs);
+    return FPOrdNotEqual(ctx, lhs, rhs);
 }
 
 Id EmitFPOrdNotEqual64(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFOrdNotEqual(ctx.U1, lhs, rhs);
+    return FPOrdNotEqual(ctx, lhs, rhs);
 }
 
 Id EmitFPUnordNotEqual16(EmitContext& ctx, Id lhs, Id rhs) {
@@ -275,15 +297,15 @@ Id EmitFPOrdLessThan64(EmitContext& ctx, Id lhs, Id rhs) {
 }
 
 Id EmitFPUnordLessThan16(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordLessThan(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordLessThan, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordLessThan32(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordLessThan(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordLessThan, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordLessThan64(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordLessThan(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordLessThan, ctx, lhs, rhs);
 }
 
 Id EmitFPOrdGreaterThan16(EmitContext& ctx, Id lhs, Id rhs) {
@@ -299,15 +321,15 @@ Id EmitFPOrdGreaterThan64(EmitContext& ctx, Id lhs, Id rhs) {
 }
 
 Id EmitFPUnordGreaterThan16(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordGreaterThan(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordGreaterThan, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordGreaterThan32(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordGreaterThan(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordGreaterThan, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordGreaterThan64(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordGreaterThan(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordGreaterThan, ctx, lhs, rhs);
 }
 
 Id EmitFPOrdLessThanEqual16(EmitContext& ctx, Id lhs, Id rhs) {
@@ -323,15 +345,15 @@ Id EmitFPOrdLessThanEqual64(EmitContext& ctx, Id lhs, Id rhs) {
 }
 
 Id EmitFPUnordLessThanEqual16(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordLessThanEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordLessThanEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordLessThanEqual32(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordLessThanEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordLessThanEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordLessThanEqual64(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordLessThanEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordLessThanEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPOrdGreaterThanEqual16(EmitContext& ctx, Id lhs, Id rhs) {
@@ -347,15 +369,15 @@ Id EmitFPOrdGreaterThanEqual64(EmitContext& ctx, Id lhs, Id rhs) {
 }
 
 Id EmitFPUnordGreaterThanEqual16(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordGreaterThanEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordGreaterThanEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordGreaterThanEqual32(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordGreaterThanEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordGreaterThanEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPUnordGreaterThanEqual64(EmitContext& ctx, Id lhs, Id rhs) {
-    return ctx.OpFUnordGreaterThanEqual(ctx.U1, lhs, rhs);
+    return FPUnordCompare(&EmitContext::OpFUnordGreaterThanEqual, ctx, lhs, rhs);
 }
 
 Id EmitFPIsNan16(EmitContext& ctx, Id value) {