Skip to content

Commit 57cf063

Browse files
committed
Patch TensorFlow absl Android build
1 parent 53f608b commit 57cf063

1 file changed

Lines changed: 65 additions & 4 deletions

File tree

scripts/patch_tfjava.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import difflib
34
import sys
45
from pathlib import Path
56

@@ -51,12 +52,72 @@ def patch_workspace(path: Path) -> None:
5152
path.write_text(text, encoding="utf-8")
5253

5354

55+
ORIGINAL_TENSORFLOW_ABSL_PATCH = """--- ./absl/time/internal/cctz/BUILD.bazel\t2019-09-23 13:20:52.000000000 -0700
56+
+++ ./absl/time/internal/cctz/BUILD.bazel.fixed\t2019-09-23 13:20:48.000000000 -0700
57+
@@ -76,15 +76,6 @@
58+
"include/cctz/time_zone.h",
59+
"include/cctz/zone_info_source.h",
60+
],
61+
- linkopts = select({
62+
- ":osx": [
63+
- "-framework Foundation",
64+
- ],
65+
- ":ios": [
66+
- "-framework Foundation",
67+
- ],
68+
- "//conditions:default": [],
69+
- }),
70+
visibility = ["//visibility:public"],
71+
deps = [":civil_time"],
72+
)
73+
--- ./absl/strings/string_view.h\t2019-09-23 13:20:52.000000000 -0700
74+
+++ ./absl/strings/string_view.h.fixed\t2019-09-23 13:20:48.000000000 -0700
75+
@@ -492,7 +492,14 @@
76+
(std::numeric_limits<difference_type>::max)();
77+
78+
static constexpr size_type CheckLengthInternal(size_type len) {
79+
+#if defined(__NVCC__) && (__CUDACC_VER_MAJOR__<10 || (__CUDACC_VER_MAJOR__==10 && __CUDACC_VER_MINOR__<2)) && !defined(NDEBUG)
80+
+ // An nvcc bug treats the original return expression as a non-constant,
81+
+ // which is not allowed in a constexpr function. This only happens when
82+
+ // NDEBUG is not defined. This will be fixed in the CUDA 10.2 release.
83+
+ return len;
84+
+#else
85+
return ABSL_ASSERT(len <= kMaxSize), len;
86+
+#endif
87+
}
88+
89+
const char* ptr_;
90+
"""
91+
92+
93+
ANDROID_GRAPHCYCLES_PATCH_HUNK = """--- ./absl/synchronization/internal/graphcycles.cc\t2019-09-23 13:20:52.000000000 -0700
94+
+++ ./absl/synchronization/internal/graphcycles.cc.fixed\t2019-09-23 13:20:48.000000000 -0700
95+
@@ -35,6 +35,7 @@
96+
#include "absl/synchronization/internal/graphcycles.h"
97+
98+
#include <algorithm>
99+
#include <array>
100+
+#include <limits>
101+
#include "absl/base/internal/hide_ptr.h"
102+
#include "absl/base/internal/raw_logging.h"
103+
#include "absl/base/internal/spinlock.h"
104+
"""
105+
106+
54107
def write_tensorflow_android_absl_patch(path: Path) -> None:
55108
path.parent.mkdir(parents=True, exist_ok=True)
56-
text = """diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
57-
--- a/tensorflow/workspace.bzl
58-
+++ b/tensorflow/workspace.bzl
59-
@@ -186,7 +186,8 @@ def tf_workspace(path_prefix = \"\", tf_repo_name = \"\"):\n tf_http_archive(\n name = \"com_google_absl\",\n build_file = clean_dep(\"//third_party:com_google_absl.BUILD\"),\n # TODO: Remove the patch when https://github.com/abseil/abseil-cpp/issues/326 is resolved\n # and when TensorFlow is build against CUDA 10.2\n- patch_file = clean_dep(\"//third_party:com_google_absl_fix_mac_and_nvcc_build.patch\"),\n+ patch_file = clean_dep(\"//third_party:com_google_absl_fix_mac_and_nvcc_build.patch\"),\n+ patch_cmds = [\"grep -q '^#include <limits>$' absl/synchronization/internal/graphcycles.cc || sed -i '/#include <algorithm>/a #include <limits>' absl/synchronization/internal/graphcycles.cc\"],\n sha256 = \"acd93f6baaedc4414ebd08b33bebca7c7a46888916101d8c0b8083573526d070\", # SHARED_ABSL_SHA\n strip_prefix = \"abseil-cpp-43ef2148c0936ebf7cb4be6b19927a9d9d145b8f\",\n urls = [\n"""
109+
text = "".join(
110+
difflib.unified_diff(
111+
ORIGINAL_TENSORFLOW_ABSL_PATCH.splitlines(keepends=True),
112+
(ORIGINAL_TENSORFLOW_ABSL_PATCH + ANDROID_GRAPHCYCLES_PATCH_HUNK).splitlines(
113+
keepends=True
114+
),
115+
fromfile="a/third_party/com_google_absl_fix_mac_and_nvcc_build.patch",
116+
tofile="b/third_party/com_google_absl_fix_mac_and_nvcc_build.patch",
117+
)
118+
)
119+
if not text.endswith("\n"):
120+
text += "\n"
60121
path.write_text(text, encoding="utf-8")
61122

62123

0 commit comments

Comments
 (0)