diff --git a/src/mctpd.c b/src/mctpd.c
index 6f3eeab..e71e9a8 100644
--- a/src/mctpd.c
+++ b/src/mctpd.c
@@ -140,6 +140,8 @@ struct link {
 	struct ctx *ctx;
 };
 
+struct vdm_type_support;
+
 struct peer {
 	uint32_t net;
 	mctp_eid_t eid;
@@ -177,6 +179,9 @@ struct peer {
 	uint8_t *message_types;
 	size_t num_message_types;
 
+	struct vdm_type_support *vdm_types;
+	size_t num_vdm_types;
+
 	// From Get Endpoint ID
 	uint8_t endpoint_type;
 	uint8_t medium_spec;
@@ -2013,6 +2018,7 @@ static int remove_peer(struct peer *peer)
 	n->peers[peer->eid] = NULL;
 
 	free(peer->message_types);
+	free(peer->vdm_types);
 	free(peer->uuid);
 
 	for (idx = 0; idx < ctx->num_peers; idx++) {
@@ -2055,6 +2061,7 @@ static void free_peers(struct ctx *ctx)
 	for (size_t i = 0; i < ctx->num_peers; i++) {
 		struct peer *peer = ctx->peers[i];
 		free(peer->message_types);
+		free(peer->vdm_types);
 		free(peer->uuid);
 		free(peer->path);
 		sd_bus_slot_unref(peer->slot_obmc_endpoint);
@@ -2512,6 +2519,116 @@ static int query_get_peer_msgtypes(struct peer *peer)
 	return rc;
 }
 
+static int query_get_peer_vdm_types(struct peer *peer)
+{
+	struct mctp_ctrl_resp_get_vdm_support *resp = NULL;
+	struct vdm_type_support *cur_vdm_type, *new_vdm;
+	struct mctp_ctrl_cmd_get_vdm_support req;
+	size_t buf_size, expect_size, new_size;
+	struct sockaddr_mctp_ext addr;
+	uint8_t *buf = NULL;
+	uint16_t *cmd_set;
+	uint8_t iid;
+	int rc;
+
+	peer->num_vdm_types = 0;
+	free(peer->vdm_types);
+	peer->vdm_types = NULL;
+
+	req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_VENDOR_MESSAGE_SUPPORT;
+	req.vendor_id_set_selector = 0;
+
+	while (req.vendor_id_set_selector !=
+	       MCTP_GET_VDM_SUPPORT_NO_MORE_CAP_SET) {
+		iid = mctp_next_iid(peer->ctx);
+
+		mctp_ctrl_msg_hdr_init_req(
+			&req.ctrl_hdr, iid,
+			MCTP_CTRL_CMD_GET_VENDOR_MESSAGE_SUPPORT);
+		rc = endpoint_query_peer(peer, MCTP_CTRL_HDR_MSG_TYPE, &req,
+					 sizeof(req), &buf, &buf_size, &addr);
+		if (rc < 0)
+			goto out;
+
+		/* Check for minimum length of PCIe VDM */
+		expect_size = sizeof(*resp);
+		rc = mctp_ctrl_validate_response(
+			buf, buf_size, expect_size, peer_tostr_short(peer), iid,
+			MCTP_CTRL_CMD_GET_VENDOR_MESSAGE_SUPPORT);
+		if (rc)
+			goto out;
+
+		resp = (void *)buf;
+		if (resp->vendor_id_format !=
+			    MCTP_GET_VDM_SUPPORT_PCIE_FORMAT_ID &&
+		    resp->vendor_id_format !=
+			    MCTP_GET_VDM_SUPPORT_IANA_FORMAT_ID) {
+			warnx("%s: bad vendor_id_format 0x%02x dest %s",
+			      __func__, resp->vendor_id_format,
+			      peer_tostr(peer));
+			rc = -ENOMSG;
+			goto out;
+		}
+
+		if (resp->vendor_id_format ==
+		    MCTP_GET_VDM_SUPPORT_IANA_FORMAT_ID) {
+			/* Accommodate 2 extra bytes for the IANA VID */
+			expect_size += sizeof(uint16_t);
+		}
+
+		if (buf_size != expect_size) {
+			warnx("%s: bad reply length: got %zu, expected %zu, dest %s",
+			      __func__, buf_size, expect_size,
+			      peer_tostr(peer));
+			rc = -ENOMSG;
+			goto out;
+		}
+
+		new_size = (peer->num_vdm_types + 1) *
+			   sizeof(struct vdm_type_support);
+		new_vdm = realloc(peer->vdm_types, new_size);
+		if (!new_vdm) {
+			rc = -ENOMEM;
+			goto out;
+		}
+		peer->vdm_types = new_vdm;
+		cur_vdm_type = peer->vdm_types + peer->num_vdm_types;
+		cur_vdm_type->format = resp->vendor_id_format;
+
+		if (resp->vendor_id_format ==
+		    MCTP_GET_VDM_SUPPORT_IANA_FORMAT_ID) {
+			cur_vdm_type->vendor_id.iana =
+				be32toh(resp->vendor_id_data_iana);
+		} else {
+			cur_vdm_type->vendor_id.pcie =
+				be16toh(resp->vendor_id_data_pcie);
+		}
+		// Assume IANA and adjust if PCIe
+		cmd_set = (uint16_t *)(resp + 1);
+		if (resp->vendor_id_format ==
+		    MCTP_GET_VDM_SUPPORT_PCIE_FORMAT_ID) {
+			cmd_set--;
+		}
+		cur_vdm_type->cmd_set = be16toh(*cmd_set);
+		peer->num_vdm_types++;
+
+		/* Use the next selector from the response. 0xFF indicates no more entries */
+		req.vendor_id_set_selector = resp->vendor_id_set_selector;
+		free(buf);
+		buf = NULL;
+	}
+	rc = 0;
+
+out:
+	free(buf);
+	if (rc < 0) {
+		free(peer->vdm_types);
+		peer->vdm_types = NULL;
+		peer->num_vdm_types = 0;
+	}
+	return rc;
+}
+
 static int peer_set_uuid(struct peer *peer, const uint8_t uuid[16])
 {
 	if (!peer->uuid) {
@@ -2946,6 +3063,7 @@ static int method_learn_endpoint(sd_bus_message *call, void *data,
 static int query_peer_properties(struct peer *peer)
 {
 	const unsigned int max_retries = 4;
+	bool supports_vdm = false;
 	int rc;
 
 	for (unsigned int i = 0; i < max_retries; i++) {
@@ -2974,6 +3092,39 @@ static int query_peer_properties(struct peer *peer)
 		}
 	}
 
+	for (unsigned int i = 0; i < peer->num_message_types; i++) {
+		if (peer->message_types[i] ==
+			    MCTP_GET_VDM_SUPPORT_IANA_FORMAT_ID ||
+		    peer->message_types[i] ==
+			    MCTP_GET_VDM_SUPPORT_PCIE_FORMAT_ID) {
+			supports_vdm = true;
+			break;
+		}
+	}
+
+	for (unsigned int i = 0; supports_vdm && i < max_retries; i++) {
+		rc = query_get_peer_vdm_types(peer);
+		// Success
+		if (rc == 0)
+			break;
+
+		// On timeout, retry
+		if (rc == -ETIMEDOUT) {
+			if (peer->ctx->verbose)
+				warnx("Retrying to get vendor message types for %s. Attempt %u",
+				      peer_tostr(peer), i + 1);
+			rc = 0;
+			continue;
+		}
+
+		if (rc < 0) {
+			warnx("Error getting vendor message types for %s. Ignoring error %d %s",
+			      peer_tostr(peer), rc, strerror(-rc));
+			rc = 0;
+			break;
+		}
+	}
+
 	for (unsigned int i = 0; i < max_retries; i++) {
 		rc = query_get_peer_uuid(peer);
 
@@ -3866,6 +4017,42 @@ static int bus_endpoint_get_prop(sd_bus *bus, const char *path,
 		rc = sd_bus_message_append_array(reply, 'y', peer->message_types,
 						 peer->num_message_types);
+	} else if (strcmp(property, "VendorDefinedMessageTypes") == 0) {
+		rc = sd_bus_message_open_container(reply, 'a', "(yvu)");
+		if (rc < 0)
+			return rc;
+
+		for (size_t i = 0; i < peer->num_vdm_types; i++) {
+			struct vdm_type_support *vdm = &peer->vdm_types[i];
+			rc = sd_bus_message_open_container(reply, 'r', "yvu");
+			if (rc < 0)
+				return rc;
+
+			rc = sd_bus_message_append(reply, "y", vdm->format);
+			if (rc < 0)
+				return rc;
+
+			if (vdm->format == VID_FORMAT_PCIE) {
+				rc = sd_bus_message_append(reply, "v", "q",
+							   vdm->vendor_id.pcie);
+			} else {
+				rc = sd_bus_message_append(reply, "v", "u",
+							   vdm->vendor_id.iana);
+			}
+			if (rc < 0)
+				return rc;
+
+			rc = sd_bus_message_append(reply, "u",
+						   (uint32_t)vdm->cmd_set);
+			if (rc < 0)
+				return rc;
+
+			rc = sd_bus_message_close_container(reply);
+			if (rc < 0)
+				return rc;
+		}
+
+		rc = sd_bus_message_close_container(reply);
 	} else if (strcmp(property, "UUID") == 0 && peer->uuid) {
 		const char *s = dfree(bytes_to_uuid(peer->uuid));
 		rc = sd_bus_message_append(reply, "s", s);
@@ -4055,6 +4242,11 @@ static const sd_bus_vtable bus_endpoint_obmc_vtable[] = {
 			bus_endpoint_get_prop,
 			0,
 			SD_BUS_VTABLE_PROPERTY_CONST),
+	SD_BUS_PROPERTY("VendorDefinedMessageTypes",
+			"a(yvu)",
+			bus_endpoint_get_prop,
+			0,
+			SD_BUS_VTABLE_PROPERTY_CONST),
 	SD_BUS_VTABLE_END
 };
 
diff --git a/tests/test_mctpd.py b/tests/test_mctpd.py
index 55bf13f..375ff78 100644
--- a/tests/test_mctpd.py
+++ b/tests/test_mctpd.py
@@ -596,6 +596,143 @@ async def test_query_message_types(dbus, mctpd):
 
     assert ep_types == query_types
 
+""" Test that VendorDefinedMessageTypes property is queried and populated correctly """
+async def test_query_vdm_types(dbus, mctpd):
+    class VDMEndpoint(Endpoint):
+        async def handle_mctp_control(self, sock, addr, data):
+            flags, opcode = data[0:2]
+            if opcode != 0x06:
+                return await super().handle_mctp_control(sock, addr, data)
+            vdm_support = [[0, 0x1234, 0x5678], [1, 0xabcdef12, 0x3456]]
+            iid = flags & 0x1f
+            raddr = MCTPSockAddr.for_ep_resp(self, addr, sock.addr_ext)
+            hdr = [iid, opcode]
+            selector = data[2]
+            if selector >= len(vdm_support):
+                await sock.send(raddr, bytes(hdr + [0x02]))
+                return
+            cur_vdm = vdm_support[selector]
+            selector = 0xFF if selector == (len(vdm_support) - 1) else selector + 1
+            resp = hdr + [0x00, selector, cur_vdm[0]]
+            if cur_vdm[0] == 0:
+                resp = resp + list(cur_vdm[1].to_bytes(2, 'big'))
+            else:
+                resp = resp + list(cur_vdm[1].to_bytes(4, 'big'))
+            resp = resp + list(cur_vdm[2].to_bytes(2, 'big'))
+            await sock.send(raddr, bytes(resp))
+
+    iface = mctpd.system.interfaces[0]
+    ep = VDMEndpoint(iface, bytes([0x1e]), eid = 15)
+    mctpd.network.add_endpoint(ep)
+
+    mctp = await mctpd_mctp_iface_obj(dbus, iface)
+    (eid, net, path, new) = await mctp.call_learn_endpoint(ep.lladdr)
+
+    assert eid == ep.eid
+
+    ep_obj = await mctpd_mctp_endpoint_common_obj(dbus, path)
+
+    # Query VendorDefinedMessageTypes property
+    vdm_types = list(await ep_obj.get_vendor_defined_message_types())
+
+    # Verify we got 2 VDM types
+    assert len(vdm_types) == 2
+
+    # Verify first VDM type: PCIe format (0), VID 0x1234, cmd_set 0x5678
+    assert vdm_types[0][0] == 0  # format: PCIe
+    assert vdm_types[0][1].value == 0x1234  # vendor_id (variant containing uint16)
+    assert vdm_types[0][2] == 0x5678  # cmd_set
+
+    # Verify second VDM type: IANA format (1), VID 0xabcdef12, cmd_set 0x3456
+    assert vdm_types[1][0] == 1  # format: IANA
+    assert vdm_types[1][1].value == 0xabcdef12  # vendor_id (variant containing uint32)
+    assert vdm_types[1][2] == 0x3456  # cmd_set
+
+""" Test VDM query with invalid responses """
+async def test_query_vdm_types_invalid(dbus, mctpd):
+    class InvalidVDMEndpointBase(Endpoint):
+        async def handle_mctp_control(self, sock, addr, data):
+            flags, opcode = data[0:2]
+            if opcode != 0x06:
+                return await super().handle_mctp_control(sock, addr, data)
+            iid = flags & 0x1f
+            raddr = MCTPSockAddr.for_ep_resp(self, addr, sock.addr_ext)
+            hdr = [iid, opcode]
+            selector = data[2]
+            if selector != 0:
+                await sock.send(raddr, bytes(hdr + [0x02]))
+                return
+            resp = hdr + [0x00, 0xFF] + self.get_invalid_vdm_data()
+            await sock.send(raddr, bytes(resp))
+
+        def get_invalid_vdm_data(self):
+            raise NotImplementedError
+
+    class InvalidPCIeLengthEndpoint(InvalidVDMEndpointBase):
+        def get_invalid_vdm_data(self):
+            # Format 0 (PCIe) but send 3 bytes for vendor_id (invalid)
+            return [0, 0x12, 0x34, 0x56, 0x78, 0x90]
+
+    class InvalidIANALengthEndpoint(InvalidVDMEndpointBase):
+        def get_invalid_vdm_data(self):
+            # Format 1 (IANA) but send 3 bytes for vendor_id (invalid)
+            return [1, 0xab, 0xcd, 0xef, 0x34, 0x56]
+
+    class InvalidFormatEndpoint(InvalidVDMEndpointBase):
+        def get_invalid_vdm_data(self):
+            # Format 2 (invalid - only 0 and 1 are valid)
+            return [2, 0x12, 0x34, 0x56, 0x78]
+
+    class UnsupportedCommandEndpoint(Endpoint):
+        async def handle_mctp_control(self, sock, addr, data):
+            flags, opcode = data[0:2]
+            if opcode != 0x06:
+                return await super().handle_mctp_control(sock, addr, data)
+            # Return error completion code: command not supported
+            iid = flags & 0x1f
+            raddr = MCTPSockAddr.for_ep_resp(self, addr, sock.addr_ext)
+            resp = bytes([iid, opcode, 0x05])  # cc=0x05 (unsupported command)
+            await sock.send(raddr, resp)
+
+    iface = mctpd.system.interfaces[0]
+    mctp = await mctpd_mctp_iface_obj(dbus, iface)
+
+    # Test 1: Invalid PCIe length
+    ep1 = InvalidPCIeLengthEndpoint(iface, bytes([0x1e]), eid = 15)
+    mctpd.network.add_endpoint(ep1)
+    (eid1, net1, path1, new1) = await mctp.call_learn_endpoint(ep1.lladdr)
+    assert eid1 == ep1.eid
+    ep_obj1 = await mctpd_mctp_endpoint_common_obj(dbus, path1)
+    vdm_types1 = list(await ep_obj1.get_vendor_defined_message_types())
+    assert len(vdm_types1) == 0
+
+    # Test 2: Invalid IANA length
+    ep2 = InvalidIANALengthEndpoint(iface, bytes([0x1f]), eid = 16)
+    mctpd.network.add_endpoint(ep2)
+    (eid2, net2, path2, new2) = await mctp.call_learn_endpoint(ep2.lladdr)
+    assert eid2 == ep2.eid
+    ep_obj2 = await mctpd_mctp_endpoint_common_obj(dbus, path2)
+    vdm_types2 = list(await ep_obj2.get_vendor_defined_message_types())
+    assert len(vdm_types2) == 0
+
+    # Test 3: Invalid format type
+    ep3 = InvalidFormatEndpoint(iface, bytes([0x20]), eid = 17)
+    mctpd.network.add_endpoint(ep3)
+    (eid3, net3, path3, new3) = await mctp.call_learn_endpoint(ep3.lladdr)
+    assert eid3 == ep3.eid
+    ep_obj3 = await mctpd_mctp_endpoint_common_obj(dbus, path3)
+    vdm_types3 = list(await ep_obj3.get_vendor_defined_message_types())
+    assert len(vdm_types3) == 0
+
+    # Test 4: Unsupported command error
+    ep4 = UnsupportedCommandEndpoint(iface, bytes([0x21]), eid = 18)
+    mctpd.network.add_endpoint(ep4)
+    (eid4, net4, path4, new4) = await mctp.call_learn_endpoint(ep4.lladdr)
+    assert eid4 == ep4.eid
+    ep_obj4 = await mctpd_mctp_endpoint_common_obj(dbus, path4)
+    vdm_types4 = list(await ep_obj4.get_vendor_defined_message_types())
+    assert len(vdm_types4) == 0
+
 """ Network1.LocalEIDs should reflect locally-assigned EID state """
 async def test_network_local_eids_single(dbus, mctpd):
     iface = mctpd.system.interfaces[0]