Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions auto/cuda_runtime.h
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/* This file is needed to workaround issue with parsing system headers. */
66 changes: 66 additions & 0 deletions auto/cuew.template.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,13 @@ typedef void* DynamicLibrary;
_LIBRARY_FIND_CHECKED(nvrtc_lib, name)
#define NVRTC_LIBRARY_FIND(name) _LIBRARY_FIND(nvrtc_lib, name)

#define CUDNN_LIBRARY_FIND_CHECKED(name) \
_LIBRARY_FIND_CHECKED(cudnn_lib, name)
#define CUDNN_LIBRARY_FIND(name) _LIBRARY_FIND(cudnn_lib, name)

static DynamicLibrary cuda_lib;
static DynamicLibrary nvrtc_lib;
static DynamicLibrary cudnn_lib;

/* Function definitions. */
%FUNCTION_DEFINITIONS%
Expand Down Expand Up @@ -208,6 +213,60 @@ static int cuewNvrtcInit(void) {
return result;
}

static void cuewExitCudnn(void) {
if (cudnn_lib != NULL) {
/* Ignore errors. */
dynamic_library_close(cudnn_lib);
cudnn_lib = NULL;
}
}

static int cuewCudnnInit(void) {
/* Library paths. */
#ifdef _WIN32
/* Expected in c:/windows/system or similar, no path needed. */
const char *cudnn_paths[] = {"cudnn.dll", NULL};
#elif defined(__APPLE__)
/* Default installation path. */
const char *cudnn_paths[] = {"/usr/local/cuda/lib/libcudnn.dylib", NULL};
#else
const char *cudnn_paths[] = {"libcudnn.so",
# if defined(__x86_64__) || defined(_M_X64)
"/usr/local/cuda/lib64/libcudnn.so",
#else
"/usr/local/cuda/lib/libcudnn.so",
#endif
NULL};
#endif
static int initialized = 0;
static int result = 0;
int error;

if (initialized) {
return result;
}

initialized = 1;

error = atexit(cuewExitCudnn);
if (error) {
result = CUEW_ERROR_ATEXIT_FAILED;
return result;
}

/* Load library. */
cudnn_lib = dynamic_library_open_find(cudnn_paths);

if (cudnn_lib == NULL) {
result = CUEW_ERROR_OPEN_FAILED;
return result;
}

%LIB_FIND_CUDNN%

result = CUEW_SUCCESS;
return result;
}

int cuewInit(cuuint32_t flags) {
int result = CUEW_SUCCESS;
Expand All @@ -226,6 +285,13 @@ int cuewInit(cuuint32_t flags) {
}
}

if (flags & CUEW_INIT_CUDNN) {
result = cuewCudnnInit();
if (result != CUEW_SUCCESS) {
return result;
}
}

return result;
}

Expand Down
14 changes: 11 additions & 3 deletions auto/cuew.template.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ typedef unsigned long long CUdeviceptr;
typedef unsigned int CUdeviceptr;
#endif


#ifdef _WIN32
# define CUDAAPI __stdcall
# define CUDA_CB __stdcall
Expand All @@ -60,6 +59,14 @@ typedef unsigned int CUdeviceptr;
# define CUDA_CB
#endif

#if !defined(__CUDACC__)
# define __device_builtin__
#else
# define __device_builtin__ __location__(device_builtin)
#endif

typedef __device_builtin__ struct CUstream_st *cudaStream_t;

%TYPEDEFS%


Expand All @@ -78,8 +85,9 @@ enum {
};

enum {
CUEW_INIT_CUDA = 1,
CUEW_INIT_NVRTC = 2
CUEW_INIT_CUDA = (1 << 0),
CUEW_INIT_NVRTC = (1 << 1),
CUEW_INIT_CUDNN = (1 << 2),
};

int cuewInit(cuuint32_t flags);
Expand Down
33 changes: 27 additions & 6 deletions auto/cuew_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from subprocess import Popen, PIPE

INCLUDE_DIR = "/usr/include"
FILES = ["cuda.h", "cudaGL.h", 'nvrtc.h']
FILES = ["cuda.h", "cudaGL.h", 'nvrtc.h', 'cudnn.h']

TYPEDEFS = []
FUNC_TYPEDEFS = []
Expand Down Expand Up @@ -113,7 +113,10 @@ def _stringify_param(self, param):
# TODO(sergey): Workaround to deal with the
# preprocessed file where array size got
# substituded.
dim = param_type.dim.value
if param_type.dim:
dim = param_type.dim.value
else:
dim = ""
if param.name == "reserved" and dim == "64":
dim = "CU_IPC_HANDLE_SIZE"
result += '[' + dim + ']'
Expand Down Expand Up @@ -196,7 +199,10 @@ def visit_Typedef(self, node):
self.indent += 1
struct = self._stringify_struct(node.type.type)
self.indent -= 1
typedef = quals + type + " {\n" + struct + "} " + node.name
if node.type.type.name:
typedef = quals + type + " {\n" + struct + "} " + node.name
else:
typedef = quals + "struct {\n" + struct + "} " + node.name
complex = True
elif isinstance(node.type.type, c_ast.Enum):
self.indent += 1
Expand Down Expand Up @@ -236,6 +242,8 @@ def preprocess_file(filename, cpp_path):
args.append("-DCUDA_ENABLE_DEPRECATED=1 ")
if filename.endswith("GL.h"):
args.append("-DCUDAAPI= ")
if filename.endswith("cudnn.h"):
args.append("-DCUDNNWINAPI= ")
args.append(filename)

try:
Expand Down Expand Up @@ -267,13 +275,17 @@ def parse_files():
"CUdevice": "void *",
"CUcontext": "void *",
"CUdeviceptr": "void *",
"CUstream": "void *"
"CUstream": "void *",
}

text = "typedef int GLint;\n" + text
text = "typedef unsigned int GLuint;\n" + text
text = "typedef unsigned int GLenum;\n" + text
text = "typedef long size_t;\n" + text
elif filepath.endswith("cudnn.h"):
dummy_typedefs = {
"cudaStream_t": "void *",
}

for typedef in sorted(dummy_typedefs):
text = "typedef " + dummy_typedefs[typedef] + " " + \
Expand All @@ -290,10 +302,13 @@ def parse_files():
if token[0] not in ("__cuda_cuda_h__",
"CUDA_CB",
"CUDAAPI",
"CUDNNWINAPI",
"CUDAGL_H",
"__NVRTC_H__",
"CUDA_ENABLE_DEPRECATED",
"__CUDA_DEPRECATED"):
"__CUDA_DEPRECATED",
"CUDNN_H_",
"__NVRTC_H__"):
DEFINES.append(token)

for line in lines:
Expand Down Expand Up @@ -374,7 +389,7 @@ def print_implementation():
lib_find_cuda = ''
for symbol in SYMBOLS:
if symbol:
if not symbol.startswith('nvrtc'):
if not symbol.startswith('nvrtc') and not symbol.startswith('cudnn'):
lib_find_cuda += " CUDA_LIBRARY_FIND(%s);\n" % (symbol)
else:
lib_find_cuda += "\n"
Expand All @@ -384,10 +399,16 @@ def print_implementation():
if symbol and symbol.startswith('nvrtc'):
lib_find_nvrtc += " NVRTC_LIBRARY_FIND(%s);\n" % (symbol)

lib_find_cudnn = ''
for symbol in SYMBOLS:
if symbol and symbol.startswith('cudnn'):
lib_find_cudnn += " CUDNN_LIBRARY_FIND(%s);\n" % (symbol)

source = source.replace('%FUNCTION_DEFINITIONS%', function_definitions.rstrip())
source = source.replace('%CUDA_ERRORS%', cuda_errors.rstrip())
source = source.replace('%LIB_FIND_CUDA%', lib_find_cuda.rstrip())
source = source.replace('%LIB_FIND_NVRTC%', lib_find_nvrtc.rstrip())
source = source.replace('%LIB_FIND_CUDNN%', lib_find_cudnn.rstrip())

sys.stdout.write(source)

Expand Down
3 changes: 3 additions & 0 deletions auto/driver_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/* This file is needed to workaround issue with parsing system headers. */

typedef long size_t;
11 changes: 11 additions & 0 deletions cuewTest/cuewTest.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,16 @@ int main(int argc, char* argv[]) {
printf("NVRTC not found\n");
}

if (cuewInit(CUEW_INIT_CUDNN) == CUEW_SUCCESS) {
printf("CUDNN found\n");
size_t version = cudnnGetVersion();
printf("Found Deep Neural Network library version %d.%d\n",
version / 1000,
version % 1000);
}
else {
printf("CUDNN not found\n");
}

return EXIT_SUCCESS;
}
Loading