45 changes: 29 additions & 16 deletions backend/app.py
@@ -55,35 +55,35 @@ def get_lance_connection():
 
 def serialize_arrow_value(value):
     try:
-        # Handle vector columns with special processing
-        if pa.types.is_list(value.type) and pa.types.is_floating(value.value_type):
+        # Stop immediately if the Arrow scalar is null
+        if value is None or not getattr(value, "is_valid", True):
+            return None
+
+        # 1. Handle Vector columns (Top-level OR nested)
+        if (pa.types.is_list(value.type) or pa.types.is_fixed_size_list(value.type)) and getattr(value.type, "value_type", None) and pa.types.is_floating(value.type.value_type):
             try:
                 vec = value.as_py()
                 if vec is None:
                     return None
 
                 # Validate vector data
                 if not isinstance(vec, (list, tuple)) or len(vec) == 0:
                     return {"type": "vector", "error": "Invalid vector data"}
 
                 # Check for valid numeric values
                 valid_values = []
                 for v in vec:
                     if v is not None and isinstance(v, (int, float)) and not (isinstance(v, float) and (v != v or v == float('inf') or v == float('-inf'))):
                         valid_values.append(float(v))
                     else:
-                        valid_values.append(0.0) # Replace invalid values with 0
+                        valid_values.append(0.0)
 
                 if not valid_values:
                     return {"type": "vector", "error": "No valid numeric values in vector"}
 
                 # Calculate vector statistics
                 norm = float(sum(x*x for x in valid_values) ** 0.5) if valid_values else 0.0
                 vec_min = float(min(valid_values)) if valid_values else 0.0
                 vec_max = float(max(valid_values)) if valid_values else 0.0
                 vec_mean = float(sum(valid_values) / len(valid_values)) if valid_values else 0.0
 
                 # Special handling for CLIP vectors (typically 512 dimensions)
                 is_clip_vector = len(valid_values) == 512
 
                 result = {
@@ -93,25 +93,38 @@ def serialize_arrow_value(value):
                     "min": vec_min,
                     "max": vec_max,
                     "mean": vec_mean,
-                    "preview": valid_values[:32], # Show first 32 values
+                    "preview": valid_values[:32],
                 }
 
                 if is_clip_vector:
                     result["model"] = "likely_clip"
                     result["description"] = "512-dimensional CLIP embedding"
                     # For CLIP vectors, show some key statistics
                     result["stats"] = {
-                        "normalized": abs(norm - 1.0) < 0.01, # CLIP vectors are typically normalized
+                        "normalized": abs(norm - 1.0) < 0.01,
                         "sparsity": sum(1 for x in valid_values if abs(x) < 0.01) / len(valid_values),
                         "positive_ratio": sum(1 for x in valid_values if x > 0) / len(valid_values)
                     }
 
                 return result
             except Exception as vec_error:
                 logger.warning(f"Error processing vector data: {vec_error}")
                 return {"type": "vector", "error": f"Vector processing failed: {str(vec_error)}"}
 
-        # Use the general serialize_value utility for all other types
+        # 2. Handle Structs recursively to catch vectors hidden inside objects
+        if pa.types.is_struct(value.type):
+            result = {}
+            for field in value.type:
+                # In PyArrow, value[field.name] fetches the nested pa.Scalar
+                result[field.name] = serialize_arrow_value(value[field.name])
+            return result
 
+        # 3. Handle Lists recursively (e.g., Arrays of Structs containing Vectors)
+        if pa.types.is_list(value.type) or pa.types.is_large_list(value.type) or pa.types.is_fixed_size_list(value.type):
+            result = []
+            for item in value: # Iterating a PyArrow ListScalar yields nested pa.Scalars
+                result.append(serialize_arrow_value(item))
+            return result
 
+        # 4. Fallback to normal serialization for strings, ints, dates, etc.
         return serialize_value(value)
     except Exception as e:
         logger.warning(f"Error serializing value: {e}")
@@ -178,7 +191,7 @@ async def get_dataset_schema(dataset_name: str):
             "nullable": field.nullable
         }
 
-        if pa.types.is_list(field.type) and pa.types.is_floating(field.type.value_type):
+        if (pa.types.is_list(field.type) or pa.types.is_fixed_size_list(field.type)) and pa.types.is_floating(field.type.value_type):
             field_info["vector_dim"] = None
 
         schema_dict["fields"].append(field_info)
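The widened type check recurs in each endpoint below; a quick illustrative test (the helper name is ours, not part of the PR) that it accepts both list encodings of an embedding column:

```python
import pyarrow as pa

def is_vector_field(field: pa.Field) -> bool:
    # Mirrors the predicate used in the endpoints; hypothetical helper for illustration.
    return (pa.types.is_list(field.type) or pa.types.is_fixed_size_list(field.type)) \
        and pa.types.is_floating(field.type.value_type)

print(is_vector_field(pa.field("emb", pa.list_(pa.float32()))))       # True
print(is_vector_field(pa.field("emb", pa.list_(pa.float32(), 512))))  # True (fixed-size)
print(is_vector_field(pa.field("tags", pa.list_(pa.string()))))       # False
```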
@@ -207,7 +220,7 @@ async def get_dataset_columns(dataset_name: str):
             "nullable": field.nullable
         }
 
-        if pa.types.is_list(field.type) and pa.types.is_floating(field.type.value_type):
+        if (pa.types.is_list(field.type) or pa.types.is_fixed_size_list(field.type)) and pa.types.is_floating(field.type.value_type):
             col_info["is_vector"] = True
             col_info["dim"] = None
         else:
@@ -264,7 +277,7 @@ async def get_dataset_rows(
         }
 
         # Add special info for vector columns
-        if pa.types.is_list(field.type) and pa.types.is_floating(field.type.value_type):
+        if (pa.types.is_list(field.type) or pa.types.is_fixed_size_list(field.type)) and pa.types.is_floating(field.type.value_type):
             field_info["vector_info"] = {
                 "is_vector": True,
                 "element_type": str(field.type.value_type),
@@ -406,7 +419,7 @@ async def get_vector_preview(
         raise HTTPException(status_code=400, detail=f"Column '{column}' not found")
 
     field = next(field for field in table.schema if field.name == column)
-    if not (pa.types.is_list(field.type) and pa.types.is_floating(field.type.value_type)):
+    if not ((pa.types.is_list(field.type) or pa.types.is_fixed_size_list(field.type)) and pa.types.is_floating(field.type.value_type)):
        raise HTTPException(status_code=400, detail=f"Column '{column}' is not a vector column")
 
     result = table.to_arrow().select([column]).slice(0, limit)
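Illustrative only: the preview endpoint trims to the requested column and row budget before serializing anything, so per-value work stays proportional to `limit`. A plain pyarrow table stands in for the Lance-backed table here:

```python
import pyarrow as pa

tbl = pa.table({
    "id": [1, 2, 3],
    "embedding": pa.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                          type=pa.list_(pa.float32())),
})
limit = 2
preview = tbl.select(["embedding"]).slice(0, limit)  # column projection + row trim
print(preview.num_rows)  # 2
```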
43 changes: 28 additions & 15 deletions backend/serialize_value.py
@@ -7,6 +7,8 @@
 
 def _serialize_temporal(obj):
     """Convert temporal types to string representation."""
+    if obj is None:
+        return None
     if isinstance(obj, (datetime, date, time)):
         return obj.isoformat()
     if isinstance(obj, timedelta):
@@ -16,23 +18,37 @@ def _serialize_temporal(obj):
 
 def _serialize_pyarrow_scalar(obj):
     """Convert PyArrow scalar types to JSON-serializable format."""
-    if pa.types.is_binary(obj.type):
-        return base64.b64encode(obj.as_py()).decode("utf-8")
+    # PREVENTS "Cannot export buffer from null Arrow Scalar"
+    if not getattr(obj, "is_valid", True):
+        return None
+
+    # Use pa.types.is_binary to strictly target raw bytes, avoiding StringScalars
+    if pa.types.is_binary(obj.type) or pa.types.is_large_binary(obj.type):
+        val = obj.as_py()
+        return base64.b64encode(val).decode("utf-8") if val else None
 
     if pa.types.is_temporal(obj.type):
         return _serialize_temporal(obj.as_py())
 
-    if pa.types.is_list(obj.type) or pa.types.is_map(obj.type):
-        return [serialize_value(item) for item in obj.as_py()]
+    if pa.types.is_list(obj.type) or pa.types.is_map(obj.type) or pa.types.is_fixed_size_list(obj.type):
+        val = obj.as_py()
+        if val is None:
+            return None
+        return [serialize_value(item) for item in val]
 
     if pa.types.is_struct(obj.type):
+        # PREVENTS "'StructScalar' object has no attribute 'field'"
+        val = obj.as_py()
+        if val is None:
+            return None
         return {
-            field.name: serialize_value(obj.field(field.name).as_py())
-            for field in obj.type
+            k: serialize_value(v)
+            for k, v in val.items()
         }
 
     if pa.types.is_floating(obj.type):
-        return float(obj.as_py())
+        val = obj.as_py()
+        return float(val) if val is not None else None
 
     return obj.as_py()
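A sketch of the failure mode the new `is_valid` guard avoids, assuming this module is importable: the old path passed `obj.as_py()` (None for a null scalar) straight into `base64.b64encode`, which raises.

```python
import pyarrow as pa

null_bin = pa.array([None], type=pa.binary())[0]  # a null BinaryScalar
print(null_bin.is_valid)                          # False
print(_serialize_pyarrow_scalar(null_bin))        # None instead of an exception
```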

@@ -48,7 +64,7 @@ def _serialize_container(obj):
 
 def _serialize_basic_types(obj):
     """Convert basic Python types to JSON-serializable format."""
-    if isinstance(obj, (bytes, pa.BinaryScalar)):
+    if isinstance(obj, bytes):
         return base64.b64encode(obj).decode("utf-8")
     if isinstance(obj, (datetime, date, time)):
         return obj.isoformat()
@@ -62,13 +78,10 @@ def serialize_value(obj):
 def serialize_value(obj):
     """
     Recursively convert objects to JSON-serializable format.
-
-    Handles:
-    - bytes/PyArrow binary: Base64-encoded string
-    - datetime types: ISO format string
-    - PyArrow types: Python native types
-    - nested types: recursive conversion
     """
+    if obj is None:
+        return None
+
     # First try basic type conversions
     result = _serialize_basic_types(obj)
     if result is not obj:
@@ -83,4 +96,4 @@ def serialize_value(obj):
     if isinstance(obj, pa.Scalar):
         return _serialize_pyarrow_scalar(obj)
 
-    return obj
+    return obj
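And an illustrative round trip through the struct branch, which now goes through `StructScalar.as_py()` rather than a per-field accessor:

```python
import pyarrow as pa

scalar = pa.array([{"id": 1, "vec": [0.5, 0.5]}])[0]  # inferred pa.StructScalar
print(serialize_value(scalar))  # {'id': 1, 'vec': [0.5, 0.5]}
```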