Skip to content

Commit e67fccd

Browse files
Adds support for secure channels (#18)
Signed-off-by: Elena Kolevska <elena@kolevska.com>
1 parent 6dd9ac1 commit e67fccd

File tree

5 files changed

+58
-6
lines changed

5 files changed

+58
-6
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,7 @@ dmypy.json
128128
# Pyre type checker
129129
.pyre/
130130

131+
# IDEs
132+
.idea
133+
131134
coverage.lcov

durabletask/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ def __init__(self, *,
9595
host_address: Union[str, None] = None,
9696
metadata: Union[List[Tuple[str, str]], None] = None,
9797
log_handler = None,
98-
log_formatter: Union[logging.Formatter, None] = None):
99-
channel = shared.get_grpc_channel(host_address, metadata)
98+
log_formatter: Union[logging.Formatter, None] = None,
99+
secure_channel: bool = False):
100+
channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel)
100101
self._stub = stubs.TaskHubSidecarServiceStub(channel)
101102
self._logger = shared.get_logger("client", log_handler, log_formatter)
102103

durabletask/internal/shared.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,15 @@ def get_default_host_address() -> str:
2020
return "localhost:4001"
2121

2222

23-
def get_grpc_channel(host_address: Union[str, None], metadata: Union[List[Tuple[str, str]], None]) -> grpc.Channel:
23+
def get_grpc_channel(host_address: Union[str, None], metadata: Union[List[Tuple[str, str]], None], secure_channel: bool = False) -> grpc.Channel:
2424
if host_address is None:
2525
host_address = get_default_host_address()
26-
channel = grpc.insecure_channel(host_address)
26+
27+
if secure_channel:
28+
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
29+
else:
30+
channel = grpc.insecure_channel(host_address)
31+
2732
if metadata is not None and len(metadata) > 0:
2833
interceptors = [DefaultClientInterceptorImpl(metadata)]
2934
channel = grpc.intercept_channel(channel, *interceptors)

durabletask/worker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,16 @@ def __init__(self, *,
8787
host_address: Union[str, None] = None,
8888
metadata: Union[List[Tuple[str, str]], None] = None,
8989
log_handler = None,
90-
log_formatter: Union[logging.Formatter, None] = None):
90+
log_formatter: Union[logging.Formatter, None] = None,
91+
secure_channel: bool = False):
9192
self._registry = _Registry()
9293
self._host_address = host_address if host_address else shared.get_default_host_address()
9394
self._metadata = metadata
9495
self._logger = shared.get_logger("worker", log_handler, log_formatter)
9596
self._shutdown = Event()
9697
self._response_stream = None
9798
self._is_running = False
99+
self._secure_channel = secure_channel
98100

99101
def __enter__(self):
100102
return self
@@ -116,7 +118,7 @@ def add_activity(self, fn: task.Activity) -> str:
116118

117119
def start(self):
118120
"""Starts the worker on a background thread and begins listening for work items."""
119-
channel = shared.get_grpc_channel(self._host_address, self._metadata)
121+
channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel)
120122
stub = stubs.TaskHubSidecarServiceStub(channel)
121123

122124
if self._is_running:

tests/test_client.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from unittest.mock import patch
2+
3+
from durabletask.internal.shared import (DefaultClientInterceptorImpl,
4+
get_default_host_address,
5+
get_grpc_channel)
6+
7+
HOST_ADDRESS = 'localhost:50051'
8+
METADATA = [('key1', 'value1'), ('key2', 'value2')]
9+
10+
11+
def test_get_grpc_channel_insecure():
12+
with patch('grpc.insecure_channel') as mock_channel:
13+
get_grpc_channel(HOST_ADDRESS, METADATA, False)
14+
mock_channel.assert_called_once_with(HOST_ADDRESS)
15+
16+
17+
def test_get_grpc_channel_secure():
18+
with patch('grpc.secure_channel') as mock_channel, patch(
19+
'grpc.ssl_channel_credentials') as mock_credentials:
20+
get_grpc_channel(HOST_ADDRESS, METADATA, True)
21+
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
22+
23+
24+
def test_get_grpc_channel_default_host_address():
25+
with patch('grpc.insecure_channel') as mock_channel:
26+
get_grpc_channel(None, METADATA, False)
27+
mock_channel.assert_called_once_with(get_default_host_address())
28+
29+
30+
def test_get_grpc_channel_with_metadata():
31+
with patch('grpc.insecure_channel') as mock_channel, patch(
32+
'grpc.intercept_channel') as mock_intercept_channel:
33+
get_grpc_channel(HOST_ADDRESS, METADATA, False)
34+
mock_channel.assert_called_once_with(HOST_ADDRESS)
35+
mock_intercept_channel.assert_called_once()
36+
37+
# Capture and check the arguments passed to intercept_channel()
38+
args, kwargs = mock_intercept_channel.call_args
39+
assert args[0] == mock_channel.return_value
40+
assert isinstance(args[1], DefaultClientInterceptorImpl)
41+
assert args[1]._metadata == METADATA

0 commit comments

Comments
 (0)