Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 170 additions & 23 deletions crates/cloud-sdk/src/applications/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,61 @@ use chrono::{DateTime, Utc};
use derive_builder::Builder;
use futures::Stream;
use reqwest::header::HeaderValue;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use serde_json;
use std::{collections::HashMap, pin::Pin};
use std::{collections::HashMap, fmt::Display, pin::Pin};
use uuid::Uuid;

use crate::error::SdkError;

/// A custom DateTime<Utc> type that handles RFC3339 timestamps with missing 'Z' timezone indicator.
/// When deserializing, if the timestamp doesn't end with 'Z', it's automatically appended.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
#[serde(transparent)]
pub struct Rfc3339DateTime(DateTime<Utc>);

impl Rfc3339DateTime {
pub fn now() -> Self {
Self(Utc::now())
}
}

impl From<DateTime<Utc>> for Rfc3339DateTime {
fn from(value: DateTime<Utc>) -> Self {
Self(value)
}
}

impl Display for Rfc3339DateTime {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.to_rfc3339())
}
}

impl<'de> Deserialize<'de> for Rfc3339DateTime {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let mut s = String::deserialize(deserializer)?;
if !s.ends_with("Z") && !s.ends_with("+00:00") {
s.push('Z');
}

DateTime::parse_from_rfc3339(&s)
.map(|dt| Rfc3339DateTime(dt.with_timezone(&Utc)))
.map_err(serde::de::Error::custom)
}
}

impl std::ops::Deref for Rfc3339DateTime {
type Target = DateTime<Utc>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize, Builder)]
pub struct ApplicationManifest {
#[builder(setter(into))]
Expand Down Expand Up @@ -644,7 +692,7 @@ pub struct RequestProgressUpdated {
#[serde(default)]
pub attributes: Option<serde_json::Value>,
#[serde(default)]
pub created_at: Option<DateTime<Utc>>,
pub created_at: Option<Rfc3339DateTime>,
}

impl RequestEventMetadata for RequestProgressUpdated {
Expand All @@ -665,11 +713,11 @@ impl RequestEventMetadata for RequestProgressUpdated {
}

fn created_at(&self) -> Option<&DateTime<Utc>> {
self.created_at.as_ref()
self.created_at.as_ref().map(|rfc| &rfc.0)
}

fn set_created_at(&mut self, date: DateTime<Utc>) {
self.created_at = Some(date);
self.created_at = Some(Rfc3339DateTime(date));
}
}

Expand All @@ -682,7 +730,7 @@ pub struct RequestFinishedEvent {
#[serde(default)]
pub outcome: RequestOutcome,
#[serde(default)]
pub created_at: Option<DateTime<Utc>>,
pub created_at: Option<Rfc3339DateTime>,
}

impl RequestEventMetadata for RequestFinishedEvent {
Expand All @@ -703,11 +751,11 @@ impl RequestEventMetadata for RequestFinishedEvent {
}

fn created_at(&self) -> Option<&DateTime<Utc>> {
self.created_at.as_ref()
self.created_at.as_ref().map(|rfc| &rfc.0)
}

fn set_created_at(&mut self, date: DateTime<Utc>) {
self.created_at = Some(date);
self.created_at = Some(Rfc3339DateTime(date));
}
}

Expand All @@ -718,7 +766,7 @@ pub struct RequestStartedEvent {
pub application_version: String,
pub request_id: String,
#[serde(default)]
pub created_at: Option<DateTime<Utc>>,
pub created_at: Option<Rfc3339DateTime>,
}

impl RequestEventMetadata for RequestStartedEvent {
Expand All @@ -739,11 +787,11 @@ impl RequestEventMetadata for RequestStartedEvent {
}

fn created_at(&self) -> Option<&DateTime<Utc>> {
self.created_at.as_ref()
self.created_at.as_ref().map(|rfc| &rfc.0)
}

fn set_created_at(&mut self, date: DateTime<Utc>) {
self.created_at = Some(date);
self.created_at = Some(Rfc3339DateTime(date));
}
}

Expand All @@ -756,7 +804,7 @@ pub struct FunctionRunCreated {
pub function_name: String,
pub function_run_id: String,
#[serde(default)]
pub created_at: Option<DateTime<Utc>>,
pub created_at: Option<Rfc3339DateTime>,
}

impl RequestEventMetadata for FunctionRunCreated {
Expand All @@ -777,11 +825,11 @@ impl RequestEventMetadata for FunctionRunCreated {
}

fn created_at(&self) -> Option<&DateTime<Utc>> {
self.created_at.as_ref()
self.created_at.as_ref().map(|rfc| &rfc.0)
}

fn set_created_at(&mut self, date: DateTime<Utc>) {
self.created_at = Some(date);
self.created_at = Some(Rfc3339DateTime(date));
}
}

Expand All @@ -796,7 +844,7 @@ pub struct FunctionRunAssigned {
pub allocation_id: String,
pub executor_id: String,
#[serde(default)]
pub created_at: Option<DateTime<Utc>>,
pub created_at: Option<Rfc3339DateTime>,
}

impl RequestEventMetadata for FunctionRunAssigned {
Expand All @@ -817,11 +865,11 @@ impl RequestEventMetadata for FunctionRunAssigned {
}

fn created_at(&self) -> Option<&DateTime<Utc>> {
self.created_at.as_ref()
self.created_at.as_ref().map(|rfc| &rfc.0)
}

fn set_created_at(&mut self, date: DateTime<Utc>) {
self.created_at = Some(date);
self.created_at = Some(Rfc3339DateTime(date));
}
}

Expand All @@ -844,7 +892,7 @@ pub struct FunctionRunCompleted {
pub allocation_id: String,
pub outcome: FunctionRunOutcomeSummary,
#[serde(default)]
pub created_at: Option<DateTime<Utc>>,
pub created_at: Option<Rfc3339DateTime>,
}

impl RequestEventMetadata for FunctionRunCompleted {
Expand All @@ -865,11 +913,11 @@ impl RequestEventMetadata for FunctionRunCompleted {
}

fn created_at(&self) -> Option<&DateTime<Utc>> {
self.created_at.as_ref()
self.created_at.as_ref().map(|rfc| &rfc.0)
}

fn set_created_at(&mut self, date: DateTime<Utc>) {
self.created_at = Some(date);
self.created_at = Some(Rfc3339DateTime(date));
}
}

Expand All @@ -882,7 +930,7 @@ pub struct FunctionRunMatchedCache {
pub function_name: String,
pub function_run_id: String,
#[serde(default)]
pub created_at: Option<DateTime<Utc>>,
pub created_at: Option<Rfc3339DateTime>,
}

impl RequestEventMetadata for FunctionRunMatchedCache {
Expand All @@ -903,11 +951,11 @@ impl RequestEventMetadata for FunctionRunMatchedCache {
}

fn created_at(&self) -> Option<&DateTime<Utc>> {
self.created_at.as_ref()
self.created_at.as_ref().map(|rfc| &rfc.0)
}

fn set_created_at(&mut self, date: DateTime<Utc>) {
self.created_at = Some(date);
self.created_at = Some(Rfc3339DateTime(date));
}
}

Expand Down Expand Up @@ -1227,3 +1275,102 @@ pub struct ProgressUpdatesJson {
pub updates: Vec<RequestStateChangeEvent>,
pub next_token: Option<String>,
}

#[cfg(test)]
mod tests {
use super::*;
use chrono::Datelike;
use serde_json::json;

#[test]
fn test_rfc3339_datetime_with_z() {
let json = json!("2024-01-15T10:30:45Z");
let result: Result<Rfc3339DateTime, _> = serde_json::from_value(json);
assert!(result.is_ok());
}

#[test]
fn test_rfc3339_datetime_without_z() {
let json = json!("2024-01-15T10:30:45");
let result: Result<Rfc3339DateTime, _> = serde_json::from_value(json);
assert!(result.is_ok());
let dt = result.unwrap();
// Verify it was parsed correctly as UTC
assert_eq!(dt.0.year(), 2024);
assert_eq!(dt.0.month(), 1);
assert_eq!(dt.0.day(), 15);
}

#[test]
fn test_rfc3339_datetime_with_timezone_offset() {
let json = json!("2024-01-15T10:30:45+00:00");
let result: Result<Rfc3339DateTime, _> = serde_json::from_value(json);
assert!(result.is_ok());
}

#[test]
fn test_request_started_event_deserialization() {
let json = json!({
"namespace": "test",
"application_name": "app",
"application_version": "1.0",
"request_id": "req-123",
"created_at": "2024-01-15T10:30:45"
});
let result: Result<RequestStartedEvent, _> = serde_json::from_value(json);
assert!(result.is_ok());
let event = result.unwrap();
assert!(event.created_at.is_some());
}

#[test]
fn test_rfc3339_datetime_serialization() {
// Test that serializing Rfc3339DateTime produces a plain string, not a nested struct
let now = chrono::Utc::now();
let rfc_dt = Rfc3339DateTime(now);
let serialized = serde_json::to_value(&rfc_dt).unwrap();

// Should be a string, not an object
assert!(
serialized.is_string(),
"Expected serialized DateTime to be a string, got: {:?}",
serialized
);

// Should contain 'Z' at the end
let date_str = serialized.as_str().unwrap();
assert!(
date_str.ends_with('Z'),
"Expected 'Z' at end of serialized DateTime"
);
}

#[test]
fn test_request_started_event_serialization() {
// Test that serializing an event doesn't nest the created_at field
let event = RequestStartedEvent {
namespace: "test".to_string(),
application_name: "app".to_string(),
application_version: "1.0".to_string(),
request_id: "req-123".to_string(),
created_at: Some(Rfc3339DateTime(Utc::now())),
};

let serialized = serde_json::to_value(&event).unwrap();
let obj = serialized.as_object().unwrap();

// created_at should be a string directly, not an object
let created_at = &obj["created_at"];
assert!(
created_at.is_string(),
"Expected created_at to be a string, got: {:?}",
created_at
);

let date_str = created_at.as_str().unwrap();
assert!(
date_str.ends_with('Z'),
"Expected 'Z' at end of created_at value"
);
}
}
4 changes: 2 additions & 2 deletions crates/cloud-sdk/src/images/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
//! .application_name("my-app")
//! .application_version("1.0.0")
//! .function_name("main")
//! .sdk_version("0.3.12")
//! .sdk_version("0.2")
//! .build().unwrap();
//!
//! images_client.build_image(build_request);
Expand Down Expand Up @@ -113,7 +113,7 @@ impl ImagesClient {
/// .application_name("my-app")
/// .application_version("1.0.0")
/// .function_name("main")
/// .sdk_version("0.2.75")
/// .sdk_version("0.2")
/// .build()?;
///
/// images_client.build_image(request).await?;
Expand Down
11 changes: 10 additions & 1 deletion crates/cloud-sdk/src/images/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,16 @@ impl Image {
lines.push(render_build_operation(op));
}

lines.push(format!("RUN pip install tensorlake=={}", sdk_version));
if sdk_version.starts_with("~=")
|| sdk_version.starts_with(">=")
|| sdk_version.starts_with("<=")
|| sdk_version.starts_with("!=")
|| sdk_version.starts_with("==")
{
lines.push(format!("RUN pip install tensorlake{}", sdk_version));
} else {
lines.push(format!("RUN pip install tensorlake=={}", sdk_version));
}

lines.join("\n")
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cloud-sdk/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub async fn build_test_image(
.application_name(application_name)
.application_version(application_version)
.function_name(func_name)
.sdk_version("0.2.75")
.sdk_version("~=0.2")
.build()
.unwrap();

Expand Down