diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index aa7881b30aa..d638307b320 100755
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -14,6 +14,7 @@ load(
+    "if_not_windows",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load(
@@ -605,8 +606,9 @@ cc_library(
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
+    ] + if_not_windows([
         "@nvtx_archive//:nvtx",
-    ] + if_cuda_is_configured([
+    ]) + if_cuda_is_configured([
         "//tensorflow/stream_executor/cuda:cuda_stream",
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
         "//tensorflow/core/platform/default/build_config:cudnn_plugin",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 18b7713fe2a..10984b31462 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3287,9 +3287,10 @@ tf_cuda_library(
         "//third_party/eigen3",
         "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/profiler/lib:traceme",
-        "@nvtx_archive//:nvtx",
         "//tensorflow/core/profiler/internal:traceme_recorder",
-    ] + mkl_deps(),
+    ] + if_not_windows([
+        "@nvtx_archive//:nvtx",
+    ]) + mkl_deps(),
     alwayslink = 1,
 )
 
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index a2539b0a984..ebc756bb0fc 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -3,6 +3,7 @@ load(
+    "if_not_windows",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
 )
 load(
     "//third_party/mkl:build_defs.bzl",
@@ -203,9 +204,10 @@ tf_cuda_library(
             "//tensorflow/core:protos_all_cc",
             "//tensorflow/core/profiler/lib:traceme",
             "//tensorflow/core/grappler/optimizers:meta_optimizer",
-            "@nvtx_archive//:nvtx",
         ],
-    }),
+    }) + if_not_windows([
+        "@nvtx_archive//:nvtx",
+    ]),
 )
 
 tf_cc_test(
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index ecc62c5d70e..a82effd2bb2 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1086,7 +1086,7 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
                                    data_format_str);
   }
   const int rank =
-      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle x;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
 
@@ -1155,7 +1155,7 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
                                    data_format_str);
   }
   const int rank =
-      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle y_backprop;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
   ShapeHandle x;
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
index f446fb23f62..cd487b0fbe8 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
@@ -81,7 +81,7 @@ inline bool NumConvOnDeviceWithDataTypeOverThreshold(
 
   for (const auto& node : context.graph_view->GetNodes()) {
     const auto* node_def = node.node();
-    if (!IsConv2D(*node_def) and !IsConv3D(*node_def)) {
+    if (!IsConv2D(*node_def) && !IsConv3D(*node_def)) {
       continue;
     }
     const string& device_name =
@@ -401,7 +401,7 @@ Status PrintDebugLogs(string suffix, GraphDef* graph_) {
   TF_RETURN_IF_ERROR(ReadBoolFromEnvVar(
       "TF_ENABLE_LAYOUT_OPTIMIZE_GRAPH_REWRITE_LOG", /*default_value=*/false,
       &allow_print));
-  if (not allow_print) return Status::OK();
+  if (!allow_print) return Status::OK();
 
   string prepend_path = "/tmp/logs/";
   if (prepend_path.empty()) return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 9c848d5b868..f6734d7c5bd 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -292,7 +292,7 @@ Status Transposer::CreateConstPermNode(TransposeContext* context,
   node.mutable_attr()->insert({"dtype", attr_data_type});
 
   AttrValue attr_tensor;
-  Tensor tensor(DT_INT32, TensorShape({permutation.size()}));
+  Tensor tensor(DT_INT32, TensorShape({static_cast<int64>(permutation.size())}));
   for (int i = 0; i < permutation.size(); i++) {
     tensor.flat<int>()(i) = permutation[i];
   }
@@ -728,7 +728,7 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -748,7 +748,7 @@ Status BiasAddTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsBiasAdd(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   if (!ShouldProcess(*context, *node)) {
@@ -789,7 +789,7 @@ Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                             utils::MutableNodeView* node) {
   DCHECK(IsBiasAddGrad(*node->node()));
   const int rank = GetFaninPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   if (!ShouldProcess(*context, *node)) {
@@ -962,7 +962,7 @@ Status FusedBatchNormGradTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsFusedBatchNormGrad(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1335,7 +1335,7 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
   DCHECK(IsConcat(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1518,7 +1518,7 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
   DCHECK(IsReduceOp(*node->node()));
   const int rank = GetFaninPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1591,7 +1591,7 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
   DCHECK(IsShape(*node->node()));
   const int rank = GetFaninPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1636,7 +1636,7 @@ Status SliceTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
   DCHECK(IsSlice(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1907,7 +1907,7 @@ Status UnaryGradTransposer::TransposeNode(TransposeContext* context,
                                           utils::MutableNodeView* node) {
   DCHECK(IsUnaryGrad(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 25107b1d768..33628531382 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1284,7 +1284,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
   Status status;
 
   string x_format = fused_node.attr().at(kDataFormat).s();
-  if (x_format == "NCHW" or x_format == "NCDHW") {
+  if (x_format == "NCHW" || x_format == "NCDHW") {
     // Need to reshape the last 4 inputs
     NodeDef new_shape;
     const string new_shape_name =
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index be618180c2f..54e8189ee3c 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1035,7 +1035,7 @@ class FusedBatchNormOpBase : public OpKernel {
     const Tensor& side_input =
         has_side_input_ ? context->input(5) : empty_side_input_;
 
-    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+    OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
                 errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
@@ -1209,10 +1209,10 @@ class FusedBatchNormGradOpBase : public OpKernel {
     // saves inverted variance.
     const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
 
-    OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
+    OP_REQUIRES(context, y_backprop.dims() == 4 || y_backprop.dims() == 5,
                 errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         y_backprop.shape().DebugString()));
-    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+    OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
                 errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
index b4c6c706ff3..48179422eb2 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
@@ -149,27 +149,6 @@ __device__ EIGEN_STRONG_INLINE void ClearBit(T* bit_mask, int bit) {
   atomicAnd(bit_mask + bin, ~(T(1) << (bit & kRemainderMask)));
 }
 
-__global__ void FlipBoxes(Box* boxes, const int* num_batch_boxes,
-                          const int* box_strides, const int batch_size) {
-  // for (int b = 0; b < batch_size; ++b) {
-  // int box_offset = box_strides[b];
-  for (const int y : CudaGridRangeY(batch_size)) {
-    int box_offset = box_strides[y];
-    Box* curr_boxes = boxes + box_offset;
-    // if (threadIdx.x == 0) {
-    //   printf(" FBx batch=%d, box_offset=%d, num_batch_boxes=%d boxes@ %p \n",
-    //   y,
-    //          box_offset, num_batch_boxes[y],curr_boxes);
-    // }
-
-    for (int i : GpuGridRangeX(num_batch_boxes[y])) {
-      Flipped<true>(curr_boxes[i]);
-    }
-  }
-  // }
-}
-
-
 // Produce a global bitmask (result_mask) of selected boxes from bitmask
 // generated by NMSKernel Abort early if max_boxes boxes are selected.
 // Bitmask is num_boxes*bit_mask_len bits indicating whether to keep or
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 38f3475c58e..be0ec72a36f 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -44,7 +44,7 @@ namespace internal {
 // Eventually absl::strings will have native support for this and we will be
 // able to completely remove PrepareForStrCat().
 template <typename T>
-typename std::enable_if<!std::is_constructible<strings::AlphaNum, T>::value,
+typename std::enable_if<!std::is_convertible<T, strings::AlphaNum>::value,
                         string>::type
 PrepareForStrCat(const T& t) {
   std::stringstream ss;
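Context on the trait swap above: `StrCat` takes its arguments by implicit conversion to `strings::AlphaNum`, so the SFINAE guard should test implicit convertibility. `std::is_constructible` is also satisfied by explicit constructors, which would wrongly route such a type past the `stringstream` fallback. A minimal sketch of the distinction, using a hypothetical `AlphaNumLike` stand-in rather than the real `strings::AlphaNum`:

```cpp
#include <type_traits>

// Hypothetical stand-in for strings::AlphaNum, only to show the two traits
// diverging; the real AlphaNum has implicit converting constructors.
struct AlphaNumLike {
  AlphaNumLike(int) {}                   // implicit: fine as a StrCat argument
  explicit AlphaNumLike(const char*) {}  // explicit: not implicitly usable
};

// is_constructible accepts the explicit constructor...
static_assert(std::is_constructible<AlphaNumLike, const char*>::value, "");
// ...but is_convertible rejects it, which matches what passing an argument
// to a StrCat-style function actually requires.
static_assert(!std::is_convertible<const char*, AlphaNumLike>::value, "");
static_assert(std::is_convertible<int, AlphaNumLike>::value, "");

int main() { return 0; }
```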
diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc
index ad45878cee8..4440fb5f143 100644
--- a/tensorflow/core/lib/io/path.cc
+++ b/tensorflow/core/lib/io/path.cc
@@ -35,6 +35,8 @@ namespace tensorflow {
 namespace io {
 namespace internal {
 
+const char kPathSep[] = "/";
+
 string JoinPathImpl(std::initializer_list<StringPiece> paths) {
   string result;
 
@@ -46,18 +48,12 @@ string JoinPathImpl(std::initializer_list<StringPiece> paths) {
       continue;
     }
 
-    if (result[result.size() - 1] == '/') {
-      if (IsAbsolutePath(path)) {
-        strings::StrAppend(&result, path.substr(1));
-      } else {
-        strings::StrAppend(&result, path);
-      }
+    if (IsAbsolutePath(path)) path = path.substr(1);
+
+    if (result[result.size() - 1] == kPathSep[0]) {
+      strings::StrAppend(&result, path);
     } else {
-      if (IsAbsolutePath(path)) {
-        strings::StrAppend(&result, path);
-      } else {
-        strings::StrAppend(&result, "/", path);
-      }
+      strings::StrAppend(&result, kPathSep, path);
     }
   }
 
@@ -126,6 +122,7 @@ bool FixBazelEnvPath(const char* path, string* out) {
 
   return true;
 }
+
 }  // namespace internal
 
 bool IsAbsolutePath(StringPiece path) {
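The branch consolidation above is easier to see in isolation: stripping the leading separator from absolute components first means a single append handles both cases. A self-contained sketch of the same join behavior, assuming `std::string` in place of `StringPiece` and plain `/` separators (not the real TensorFlow implementation):

```cpp
#include <initializer_list>
#include <string>

const char kPathSep[] = "/";

// Sketch of the simplified JoinPathImpl logic, e.g.
// JoinPath({"/a/", "/b", "c"}) -> "/a/b/c".
std::string JoinPath(std::initializer_list<std::string> paths) {
  std::string result;
  for (std::string path : paths) {
    if (path.empty()) continue;
    if (result.empty()) {
      result = path;
      continue;
    }
    // Hoisted from the old nested branches: drop a leading separator so the
    // append below never produces "//".
    if (path.front() == kPathSep[0]) path = path.substr(1);
    if (result.back() == kPathSep[0]) {
      result += path;
    } else {
      result += kPathSep + path;
    }
  }
  return result;
}

int main() { return JoinPath({"/a/", "/b", "c"}) == "/a/b/c" ? 0 : 1; }
```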
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index ac91b79a07f..75e5b31f3ff 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -570,10 +570,4 @@ Status ReadTextOrBinaryProto(Env* env, const string& fname,
   return ReadBinaryProto(env, fname, proto);
 }
 
-int setenv(const char* name, const char* value, int overwrite) {
-  return ::setenv(name, value, overwrite);
-}
-
-int unsetenv(const char* name) { return ::unsetenv(name); }
-
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/nvtx.h b/tensorflow/core/platform/nvtx.h
index ae0802df456..c951ec1cc86 100755
--- a/tensorflow/core/platform/nvtx.h
+++ b/tensorflow/core/platform/nvtx.h
@@ -16,7 +16,11 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PLATFORM_NVTX_H_
 #define TENSORFLOW_CORE_PLATFORM_NVTX_H_
 
+#ifdef _WIN32
+#include "cuda/include/nvtx3/nvToolsExt.h"
+#else
 #include "third_party/nvtx3/nvToolsExt.h"
+#endif
 
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc
index ba2a979df16..a9975f66602 100644
--- a/tensorflow/core/platform/posix/env.cc
+++ b/tensorflow/core/platform/posix/env.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <fcntl.h>
 #include <fnmatch.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <sys/mman.h>
 #include <sys/stat.h>
 #include <sys/time.h>
@@ -258,4 +259,10 @@ void PosixEnv::GetLocalTempDirectories(std::vector<string>* list) {
   }
 }
 
+int setenv(const char* name, const char* value, int overwrite) {
+  return ::setenv(name, value, overwrite);
+}
+
+int unsetenv(const char* name) { return ::unsetenv(name); }
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 54be76375c9..ade51b65e8b 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -490,7 +490,7 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
 }
 
 template <typename InType, typename OutType, typename Functor>
-void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
+void BinaryUFunc(char** args, const npy_intp* dimensions, const npy_intp* steps,
                  void* data) {
   const char* i0 = args[0];
   const char* i1 = args[1];
@@ -505,11 +505,17 @@ void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
   }
 }
 
+// NumPy changed the const-ness of PyUFuncGenericFunction; add an overload.
 template <typename Functor>
 void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
                   void* data) {
   BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
 }
+template <typename Functor>
+void CompareUFunc(char** args, const npy_intp* dimensions,
+                  const npy_intp* steps, void* data) {
+  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
+}
 
 struct Bfloat16EqFunctor {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
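Why the duplicated `CompareUFunc`: newer NumPy headers declare `PyUFuncGenericFunction` with `const npy_intp*` parameters while older ones use non-const pointers, and the registration code needs a name that binds to whichever signature the installed headers use. A standalone sketch of the mechanism, with hypothetical typedefs standing in for NumPy's:

```cpp
#include <cstddef>

// Hypothetical stand-ins for npy_intp and the two historical spellings of
// PyUFuncGenericFunction (the real typedef gained const in newer NumPy).
using intp = std::ptrdiff_t;
using GenericFuncNew = void (*)(char**, const intp*, const intp*, void*);
using GenericFuncOld = void (*)(char**, intp*, intp*, void*);

void Ufunc(char** args, const intp* dims, const intp* steps, void* data) {}
void Ufunc(char** args, intp* dims, intp* steps, void* data) {}

int main() {
  // Taking the address of an overloaded function resolves against the target
  // pointer type, so the same name registers under either NumPy version.
  GenericFuncNew f_new = Ufunc;
  GenericFuncOld f_old = Ufunc;
  (void)f_new;
  (void)f_old;
  return 0;
}
```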
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 023f0ce6326..395754581ad 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1278,9 +1278,17 @@ port::Status CheckAndFetchProjectionWeights(
   cudnnDataType_t data_type;
 #if CUDNN_VERSION >= 8000
   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
+      /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+      /*hiddenSize=*/&hidden_size_v,
+      /*numLayers=*/&num_layers_v,
+      /*dropoutDesc=*/&dropout_desc,
+      /*inputMode=*/&input_mode,
+      /*direction=*/&direction,
+      /*mode=*/&mode,
+      /*algo=*/&algo,
+      /*mathPrec=*/&data_type));
 #else
   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
-#endif
       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
       /*hiddenSize=*/&hidden_size_v,
       /*numLayers=*/&num_layers_v,
@@ -1290,6 +1298,7 @@ port::Status CheckAndFetchProjectionWeights(
       /*mode=*/&mode,
       /*algo=*/&algo,
       /*dataType=*/&data_type));
+#endif
   int rec_proj_size_v;
   int out_proj_size_v;
   RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
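The hunk above also fixes a brittle preprocessor pattern: the old code switched only the function name between the `#if` branches and let both calls share a single argument list, which breaks as soon as the two signatures diverge (as they do in cuDNN 8, where the last parameter is `mathPrec` rather than `dataType`). Duplicating the complete call per branch keeps each one self-consistent. A schematic illustration with hypothetical `old_api`/`new_api` functions, not the real cuDNN symbols:

```cpp
// Pretend this version macro comes from the library's headers.
#define LIB_VERSION 8000

int new_api(int a, int b) { return a + b; }  // newer signature, extra param
int old_api(int a) { return a; }             // older signature

int Query() {
  // Old pattern: #if around the name only, one shared argument list --
  // compiles only while both signatures happen to match.
  // New pattern: one complete, self-consistent call per branch.
#if LIB_VERSION >= 8000
  return new_api(/*a=*/1, /*b=*/2);
#else
  return old_api(/*a=*/1);
#endif
}

int main() { return Query() == 3 ? 0 : 1; }
```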
diff --git a/tensorflow/stream_executor/cuda/cudnn_stub.cc b/tensorflow/stream_executor/cuda/cudnn_stub.cc
index 073ba3ffd00..e30f749897e 100644
--- a/tensorflow/stream_executor/cuda/cudnn_stub.cc
+++ b/tensorflow/stream_executor/cuda/cudnn_stub.cc
@@ -53,7 +53,8 @@ cudnnStatus_t GetSymbolNotFoundError() { return CUDNN_STATUS_INTERNAL_ERROR; }
 #include "tensorflow/stream_executor/cuda/cudnn_6_0.inc"
 #elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 1
 #include "tensorflow/stream_executor/cuda/cudnn_7_0.inc"
-#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 3
+// 2 instead of 3: see https://github.com/tensorflow/tensorflow/issues/32350
+#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 2
 #include "tensorflow/stream_executor/cuda/cudnn_7_1.inc"
 #elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 4
 #include "tensorflow/stream_executor/cuda/cudnn_7_3.inc"
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 382e4bd1fd2..8b4fafdd023 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -73,7 +73,6 @@ class AlgorithmDesc;
 
 class StreamExecutor;
 class ScratchAllocator;
-enum BatchNormalizationKind;
 
 // Convert a type to the corresponding QuantizedActivationMode.
 template <typename ElementType>
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8b29610c021..f04bd260ddd 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -168,11 +168,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:neon_casting_and_gpu_packet.patch"),
-        sha256 = "2f046557f4093becf51b44c6339873c18e2f1ea55c4b3f3a08b7d15a1d9c6e5b",  # SHARED_EIGEN_SHA
-        strip_prefix = "eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced",
+        sha256 = "bacd9508f8a636a616eef363d7f8d0f6da4c87b935132030a03793884a6ab4f1",  # SHARED_EIGEN_SHA
+        strip_prefix = "eigen-8c9976d7f0558fdc8d0be7476c37e5d562332955",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz",
-            "https://gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/8c9976d7f0558fdc8d0be7476c37e5d562332955/eigen-8c9976d7f0558fdc8d0be7476c37e5d562332955.tar.gz",
+            "https://gitlab.com/libeigen/eigen/-/archive/8c9976d7f0558fdc8d0be7476c37e5d562332955/eigen-8c9976d7f0558fdc8d0be7476c37e5d562332955.tar.gz",
         ],
     )
 
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
index bfe7d6c5288..1247e486903 100644
--- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -143,7 +143,7 @@ def InvokeNvcc(argv, log=False):
   nvccopts += undefines
   nvccopts += defines
   nvccopts += m_options
-  nvccopts += ['--compiler-options="' + " ".join(host_compiler_options) + '"']
+  nvccopts += ['--compiler-options=' + ",".join(host_compiler_options)]
   nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files
   # If we don't specify --keep-dir, nvcc will generate intermediate files under TEMP
   # Put them under NVCC_TEMP_DIR instead, then Bazel can ignore files under NVCC_TEMP_DIR during dependency check
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 1bd7141a372..9b625e4278a 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -531,7 +531,10 @@ def lib_name(base_name, cpu_value, version = None, static = False):
             return "lib%s.a" % base_name
         return "lib%s.so%s" % (base_name, version)
     elif cpu_value == "Windows":
-        return "%s.lib" % base_name
+        if base_name == "nvToolsExt":
+            return "lib/x64/nvToolsExt64_1.lib"
+        else:
+            return "%s.lib" % base_name
     elif cpu_value == "Darwin":
         if static:
             return "lib%s.a" % base_name
@@ -669,7 +672,7 @@ def _find_libs(repository_ctx, cuda_config):
             "nvToolsExt",
             repository_ctx,
             cpu_value,
-            cuda_config.config["cuda_library_dir"],
+            cuda_config.nvToolsExt_path,
             "1",
         ),
         "cupti": _find_cuda_lib(
@@ -762,6 +765,11 @@ def _get_cuda_config(repository_ctx):
         cufft_version = cuda_version
         cusparse_version = cuda_version
 
+    if cpu_value == "Windows":
+        nvToolsExt_path = repository_ctx.os.environ.get("NVTOOLSEXT_PATH", "C:/Program Files/NVIDIA Corporation/NvToolsExt/")
+    else:
+        nvToolsExt_path = toolkit_path
+
     return struct(
         cuda_toolkit_path = toolkit_path,
         cuda_version = cuda_version,
@@ -775,6 +783,7 @@ def _get_cuda_config(repository_ctx):
         compute_capabilities = compute_capabilities(repository_ctx),
         cpu_value = cpu_value,
         config = config,
+        nvToolsExt_path = nvToolsExt_path,
     )
 
 def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
@@ -1148,7 +1157,8 @@ def _create_local_cuda_repository(repository_ctx):
         out_dir = "cuda/bin",
     ))
 
-    if [int(x) for x in cuda_config.cudnn_version.split(".")] < [8, 0]:
+    # Select headers by cuDNN major version (strip Windows' "64_" prefix).
+    if int(cuda_config.cudnn_version.rsplit("_", 1)[-1].split(".")[0]) < 8:
       cudnn_headers = ["cudnn.h"]
     else:
       cudnn_headers = ["cudnn_adv_infer.h",