Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/scriptworker/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Context(object):

config: Optional[Dict[str, Any]] = None
credentials_timestamp: Optional[int] = None
credentials_fd: int = -1
proc: Optional[task_process.TaskProcess] = None
queue: Optional[Queue] = None
session: Optional[aiohttp.ClientSession] = None
Expand Down Expand Up @@ -98,13 +99,17 @@ def claim_task(self, claim_task: Optional[Dict[str, Any]]) -> None:
if claim_task:
self.task = claim_task["task"]
self.verify_task()
# flags=0 to let the child inherit this fd
self.credentials_fd = os.memfd_create("scriptworker_temp_creds", flags=0)
self.temp_credentials = claim_task["credentials"]
path = os.path.join(self.config["work_dir"], "task.json")
assert self.task
self.write_json(path, self.task, "Writing task file to {path}...")
else:
self.temp_credentials = None
self.task = None
os.close(self.credentials_fd)
self.credentials_fd = -1

def verify_task(self) -> None:
"""Run some task sanity checks on ``self.task``."""
Expand Down Expand Up @@ -193,6 +198,10 @@ def temp_credentials(self) -> Optional[Dict[str, Any]]:
def temp_credentials(self, credentials: Optional[Dict[str, Any]]) -> None:
self._temp_credentials = credentials
self.temp_queue = self.create_queue(self.temp_credentials)
if credentials:
data = json.dumps(credentials, indent=2, sort_keys=True).encode("ascii")
# use pwrite so we don't confuse the child by changing the file offset
assert os.pwrite(self.credentials_fd, data, 0) == len(data)

def write_json(self, path: str, contents: Dict[str, Any], message: str) -> None:
"""Write json to disk.
Expand Down
11 changes: 10 additions & 1 deletion src/scriptworker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,16 @@ async def run_task(context, to_cancellable_process):
env["TASK_ID"] = context.task_id or "None"
env["RUN_ID"] = str(get_run_id(context.claim_task))
env["TASKCLUSTER_ROOT_URL"] = context.config["taskcluster_root_url"]
kwargs = {"stdout": PIPE, "stderr": PIPE, "stdin": None, "close_fds": True, "preexec_fn": lambda: os.setsid(), "env": env} # pragma: no branch
env["TASKCLUSTER_CREDENTIALS_FD"] = str(context.credentials_fd)
kwargs = {
"stdout": PIPE,
"stderr": PIPE,
"stdin": None,
"close_fds": True,
"preexec_fn": lambda: os.setsid(),
"env": env,
"pass_fds": (context.credentials_fd,),
} # pragma: no branch

timeout = get_task_maxruntime(context.task, context.config["task_max_timeout"])
subprocess = await asyncio.create_subprocess_exec(*context.config["task_script"], **kwargs)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,46 @@ async def test_set_reset_task(rw_context, claim_task, reclaim_task):
assert rw_context.temp_queue is None


def test_credentials_fd_initial(rw_context):
assert rw_context.credentials_fd == -1


@pytest.mark.asyncio
async def test_credentials_fd_opened_on_claim_task(rw_context, claim_task):
rw_context.claim_task = claim_task
assert rw_context.credentials_fd >= 0
os.fstat(rw_context.credentials_fd) # raises OSError if fd is invalid


@pytest.mark.asyncio
async def test_credentials_fd_content(rw_context, claim_task):
rw_context.claim_task = claim_task
fd = rw_context.credentials_fd
size = os.fstat(fd).st_size
data = os.pread(fd, size, 0)
assert json.loads(data) == claim_task["credentials"]


@pytest.mark.asyncio
async def test_credentials_fd_updated_on_reclaim(rw_context, claim_task, reclaim_task):
rw_context.claim_task = claim_task
rw_context.reclaim_task = reclaim_task
fd = rw_context.credentials_fd
size = os.fstat(fd).st_size
data = os.pread(fd, size, 0)
assert json.loads(data) == reclaim_task["credentials"]


@pytest.mark.asyncio
async def test_credentials_fd_closed_on_reset(rw_context, claim_task):
rw_context.claim_task = claim_task
fd = rw_context.credentials_fd
rw_context.claim_task = None
assert rw_context.credentials_fd == -1
with pytest.raises(OSError):
os.fstat(fd)


@pytest.mark.asyncio
async def test_projects(rw_context, mocker):
fake_projects = {"mozilla-central": "blah", "count": 0}
Expand Down
18 changes: 18 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,24 @@ async def test_run_task_timeout(context):
assert context.proc is None


@pytest.mark.asyncio
async def test_run_task_credentials_fd(context):
"""The subprocess receives the temp credentials via the fd in TASKCLUSTER_CREDENTIALS_FD."""
context.config["task_script"] = (
sys.executable,
"-c",
"import json, os, sys; "
"fd = int(os.environ['TASKCLUSTER_CREDENTIALS_FD']); "
"size = os.fstat(fd).st_size; "
"sys.stdout.write(os.pread(fd, size, 0).decode('ascii'))",
)
await swtask.run_task(context, noop_to_cancellable_process)
log_file = log.get_log_filename(context)
contents = read(log_file)
parsed, _ = json.JSONDecoder().raw_decode(contents)
assert parsed == context.temp_credentials


# report* {{{1
@pytest.mark.asyncio
async def test_reportCompleted(context, successful_queue):
Expand Down