Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ bin
*.a
CMakeCache.txt
cmake-build-debug
bazel-*
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,36 @@ on Windows follow any installation instructions and install libmicrohttpd, curl
Finally, call
> bazel build //...

If using a sufficiently new compiler, boringssl dependency for gRPC may fail to build. Try:
> bazel build --copt=-Wno-error=array-parameter --copt=-Wno-error=stringop-overflow //...

To run tests, call
> bazel test //...

## Training

Run:
> bazel run //src/training/train
> bazel run //n2p/training:train
Don't forget about `--copt` from above if boringssl fails to build.

To get options for training, use:
> bazel run //src/training/train --help
> bazel run //n2p/training:train --help

By default, train gets input programs (converted to JSON for example with UnuglifyJS) from the file testdata in the current directory. As a result, it creates files with the trained model.

If you wish to train the model using pseudolikelihood use the following parameters:

> bazel run //src/training/train -- -training_method pl -input path/to/input/file --logtostderr
> bazel run //n2p/training:train -- -training_method pl -input path/to/input/file --logtostderr

you can control the pseudolikelihood specific beam size with the `-beam_size` parameter which is different from the beam size used during MAP Inference.

`//src/training/train` expects data to be in protobuf recordIO format. If you want to use JSON input - use `//src/training/train_json` instead.
`//n2p/training:train` expects data to be in protobuf recordIO format. If you want to use JSON input - use `//n2p/training:train_json` instead.

### Factors

by default the usage of factor features in Nice2Predict is enabled, however if you wish to disable it you can launch the training with the following command:

> bazel run //src/training/train -- -use_factors=false -input path/to/input/file --logtostderr
> bazel run //n2p/training:train -- -use_factors=false -input path/to/input/file --logtostderr

## Predicting properties

Expand Down
61 changes: 45 additions & 16 deletions WORKSPACE
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
git_repository(
name = "org_pubref_rules_protobuf",
remote = "https://github.com/pubref/rules_protobuf",
tag = "v0.8.1",
# commit = "d9523f3d443b6a4f3fabc72051d84eb5474d7745"
)

load("@org_pubref_rules_protobuf//cpp:rules.bzl", "cpp_proto_repositories")
cpp_proto_repositories()
load("//tools/build_defs:externals.bzl", "new_patched_http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

#BTW, @org_pubref_rules_protobuf already contains @com_google_googletest

load("//tools/build_defs:externals.bzl",
"new_patched_http_archive",
)

# The sparsehash BUILD is copied from https://github.com/livegrep/livegrep
new_patched_http_archive(
name = "com_github_sparsehash",
url = "https://github.com/sparsehash/sparsehash/archive/sparsehash-2.0.3.tar.gz",
Expand All @@ -23,3 +10,45 @@ new_patched_http_archive(
strip_prefix = "sparsehash-sparsehash-2.0.3/",
patch_file = "//third_party:sparsehash.patch",
)

http_archive(
name = "com_google_googletest",
url = "https://github.com/google/googletest/archive/refs/tags/release-1.11.0.tar.gz",
strip_prefix = "googletest-release-1.11.0",
sha256 = "b4870bf121ff7795ba20d20bcdd8627b8e088f2d1dab299a031c1034eddc93d5",
)

git_repository(
name = "com_google_googletest",
remote = "https://github.com/google/googletest.git",
tag = "release-1.11.0",
)

http_archive(
name = "com_google_protobuf",
url = "https://github.com/protocolbuffers/protobuf/releases/download/v3.17.3/protobuf-cpp-3.17.3.tar.gz",
strip_prefix = "protobuf-3.17.3",
sha256 = "51cec99f108b83422b7af1170afd7aeb2dd77d2bcbb7b6bad1f92509e9ccf8cb",
)

load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
protobuf_deps()

http_archive(
name = "com_github_madler_zlib",
url = "https://github.com/madler/zlib/archive/refs/tags/v1.2.11.tar.gz",
strip_prefix = "zlib-1.2.11",
build_file = "//third_party:BUILD.zlib",
sha256 = "629380c90a77b964d896ed37163f5c3a34f6e6d897311f1df2a7016355c45eff",
)

http_archive(
name = "com_github_grpc_grpc",
url = "https://github.com/grpc/grpc/archive/refs/tags/v1.38.1.tar.gz",
strip_prefix = "grpc-1.38.1",
sha256 = "f60e5b112913bf776a22c16a3053cc02cf55e60bf27a959fd54d7aaf8e2da6e8",
)
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
grpc_extra_deps()
2 changes: 1 addition & 1 deletion json/server_connectors_httpserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void HttpServer::SetUrlHandler(const string &url, IClientConnectionHandler *hand
this->SetHandler(NULL);
}

int HttpServer::callback(void *cls, MHD_Connection *connection, const char *url, const char *method, const char *version, const char *upload_data, size_t *upload_data_size, void **con_cls)
MHD_Result HttpServer::callback(void *cls, MHD_Connection *connection, const char *url, const char *method, const char *version, const char *upload_data, size_t *upload_data_size, void **con_cls)
{
(void)version;
if (*con_cls == NULL)
Expand Down
2 changes: 1 addition & 1 deletion json/server_connectors_httpserver.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace jsonrpc

std::map<std::string, IClientConnectionHandler*> urlhandler;

static int callback(void *cls, struct MHD_Connection *connection, const char *url, const char *method, const char *version, const char *upload_data, size_t *upload_data_size, void **con_cls);
static MHD_Result callback(void *cls, struct MHD_Connection *connection, const char *url, const char *method, const char *version, const char *upload_data, size_t *upload_data_size, void **con_cls);

IClientConnectionHandler* GetHandler(const std::string &url);

Expand Down
34 changes: 28 additions & 6 deletions n2p/protos/BUILD
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
load("@org_pubref_rules_protobuf//cpp:rules.bzl", "cc_proto_library")
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc")

proto_library(
name = "service",
srcs = ["service.proto"],
deps = [
":interface",
"@com_google_protobuf//:any_proto",
],
)

proto_library(
name = "interface",
srcs = ["interface.proto"],
)

cc_proto_library(
name = "service_cc",
deps = [":service"],
)

cc_grpc_library(
name = "service_cc_proto",
protos = ["service.proto"],
proto_deps = ["interface_cc_proto"],
with_grpc = True,
srcs = [":service"],
deps = [
":service_cc",
"@com_github_grpc_grpc//:grpc++",
],
grpc_only = True,
visibility = ["//visibility:public"],
)

cc_proto_library(
name = "interface_cc_proto",
protos = ["interface.proto"],
with_grpc = False,
deps = [":interface"],
visibility = ["//visibility:public"]
)
2 changes: 1 addition & 1 deletion n2p/server/nice2server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <gflags/gflags.h>
#include <glog/logging.h>

#include "grpc++/grpc++.h"
#include "grpcpp/grpcpp.h"

#include "base/stringprintf.h"
#include "n2p/inference/graph_inference.h"
Expand Down
48 changes: 48 additions & 0 deletions third_party/BUILD.zlib
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# ****************************************************************
# BUILD file for https://github.com/madler/zlib
# ****************************************************************
#
package(default_visibility = ["//visibility:public"])

licenses(["notice"]) # BSD/MIT-like license (for zlib)

cc_library(
name = "zlib",
srcs = [
"adler32.c",
"compress.c",
"crc32.c",
"deflate.c",
"gzclose.c",
"gzlib.c",
"gzread.c",
"gzwrite.c",
"infback.c",
"inffast.c",
"inflate.c",
"inftrees.c",
"trees.c",
"uncompr.c",
"zutil.c",
],
hdrs = [
"crc32.h",
"deflate.h",
"gzguts.h",
"inffast.h",
"inffixed.h",
"inflate.h",
"inftrees.h",
"trees.h",
"zconf.h",
"zlib.h",
"zutil.h",
],
includes = [
".",
],
copts = [
"-Wno-unused-variable",
"-Wno-implicit-function-declaration",
],
)