162 changes: 161 additions & 1 deletion DashAI/back/api/api_v1/endpoints/datasets.py
@@ -6,6 +6,7 @@
import shutil
import tempfile
import zipfile
from datetime import datetime, timezone
from typing import Any, Dict

import numpy as np
@@ -30,7 +31,11 @@
from DashAI.back.dependencies.database.models import Dataset, ModelSession
from DashAI.back.types.inf.type_inference import infer_types
from DashAI.back.types.type_validation import validate_multiple_type_changes
from DashAI.back.types.utils import arrow_to_dashai_schema
from DashAI.back.types.utils import (
arrow_to_dashai_schema,
get_types_from_arrow_metadata,
save_types_in_arrow_metadata,
)

logger = logging.getLogger(__name__)
router = APIRouter()
@@ -841,6 +846,161 @@ async def update_dataset(
) from e


@router.patch("/{dataset_id}/columns/rename")
@inject
async def rename_dataset_column(
dataset_id: int,
params: schemas.DatasetRenameColumnParams,
session_factory: sessionmaker = Depends(lambda: di["session_factory"]),
):
"""Rename a column in a dataset.

Parameters
----------
dataset_id : int
ID of the dataset to update.
params : DatasetRenameColumnParams
Parameters containing old_name and new_name for the column.
session_factory : Callable[..., ContextManager[Session]]
A factory that creates a context manager that handles a SQLAlchemy session.

Returns
-------
Dict
A dictionary with a success message and updated column types.
"""
with session_factory() as db:
dataset = db.get(Dataset, dataset_id)
if dataset is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Dataset not found"
)

if dataset.status != DatasetStatus.FINISHED:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Dataset is not in a finished state or is being modified",
)

# Validate the requested names before locking, so a validation error
# cannot leave the dataset stuck in the started state
old_name = params.old_name.strip()
new_name = params.new_name.strip()
if not old_name or not new_name:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Column names cannot be empty",
)
if old_name == new_name:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="New column name must be different from the old name",
)

# Lock the dataset to prevent concurrent modifications
try:
dataset.set_status_as_started()
db.commit()
except exc.SQLAlchemyError as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error locking dataset for modification",
) from e

dataset_path = f"{dataset.file_path}/dataset"
arrow_file_path = f"{dataset_path}/data.arrow"
try:
with pa.OSFile(arrow_file_path, "rb") as source:
reader = pa.ipc.open_file(source)
table = reader.read_all()
if old_name not in table.schema.names:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Column '{old_name}' not found in dataset",
)
if new_name in table.schema.names and new_name != old_name:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Column '{new_name}' already exists",
)

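# Rebuild the schema with the renamed field, carrying over the stored DashAI type metadata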
column_index = table.schema.get_field_index(old_name)
types_dict = get_types_from_arrow_metadata(table.schema)
new_fields = []
for i, field in enumerate(table.schema):
if i == column_index:
new_fields.append(pa.field(new_name, field.type, field.nullable))
else:
new_fields.append(field)

new_schema = pa.schema(new_fields)

renamed_table = pa.table(
{
new_name if name == old_name else name: table[name]
for name in table.schema.names
},
schema=new_schema,
)
if old_name in types_dict:
types_dict[new_name] = types_dict.pop(old_name)

types_serialized = {col: types_dict[col].to_string() for col in types_dict}
renamed_table = save_types_in_arrow_metadata(
renamed_table, types_serialized
)

with pa.OSFile(arrow_file_path, "wb") as sink:
writer = pa.ipc.new_file(sink, renamed_table.schema)
writer.write_table(renamed_table)
writer.close()

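# Keep splits.json consistent with the renamed column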
splits_path = f"{dataset_path}/splits.json"
if os.path.exists(splits_path):
with open(splits_path, "r", encoding="utf-8") as f:
splits_data = json.load(f)
if "column_names" in splits_data:
splits_data["column_names"] = [
new_name if name == old_name else name
for name in splits_data["column_names"]
]
if "nan" in splits_data and old_name in splits_data["nan"]:
splits_data["nan"][new_name] = splits_data["nan"].pop(old_name)
with open(splits_path, "w", encoding="utf-8") as f:
json.dump(
splits_data,
f,
indent=2,
sort_keys=True,
ensure_ascii=False,
)

dataset.last_modified = datetime.now(timezone.utc)
dataset.set_status_as_finished()
db.commit()
db.refresh(dataset)
updated_columns = get_columns_spec(dataset_path)
return {
"message": f"Column '{old_name}' renamed to '{new_name}' successfully",
"old_name": old_name,
"new_name": new_name,
"columns": updated_columns,
}
except HTTPException:
# Release the lock before re-raising
dataset.set_status_as_finished()
db.commit()
raise
except Exception as e:
# Release the lock and mark as finished on error
dataset.set_status_as_finished()
db.commit()
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error renaming column: {str(e)}",
) from e


@router.get("/file/")
async def get_dataset_file(
path: str,
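For reviewers who want to try the new endpoint by hand, a minimal sketch of the request it expects is shown below; the host, port, and dataset id are assumptions for illustration, while the path and JSON body come from the route and schema in this change.

import requests

# Hypothetical local DashAI instance and dataset id; adjust as needed.
BASE_URL = "http://localhost:8000/api/v1/datasets"
DATASET_ID = 1

response = requests.patch(
    f"{BASE_URL}/{DATASET_ID}/columns/rename",
    json={"old_name": "sepal_len", "new_name": "sepal_length"},
)
response.raise_for_status()
payload = response.json()
print(payload["message"])   # e.g. "Column 'sepal_len' renamed to 'sepal_length' successfully"
print(payload["new_name"])  # "sepal_length"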
5 changes: 5 additions & 0 deletions DashAI/back/api/api_v1/schemas/datasets_params.py
@@ -26,6 +26,11 @@ class DatasetUpdateParams(BaseModel):
name: str = None


class DatasetRenameColumnParams(BaseModel):
old_name: str
new_name: str


class DatasetUploadFromNotebookParams(BaseModel):
name: str

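A quick sketch of how the new schema validates the request body; the column names are made up, and the import path is assumed to mirror the file location shown above.

from DashAI.back.api.api_v1.schemas.datasets_params import DatasetRenameColumnParams

params = DatasetRenameColumnParams(old_name="sepal_len", new_name="sepal_length")
assert params.old_name == "sepal_len" and params.new_name == "sepal_length"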
30 changes: 28 additions & 2 deletions DashAI/back/job/dataset_job.py
@@ -177,10 +177,36 @@ def run(
new_dataset.to_pandas(), method="DashAIPtype"
)

# Apply any user-requested column renames before casting
if "column_renames" in params:
renames = params["column_renames"]
original_names = new_dataset.arrow_table.schema.names
new_names = [renames.get(col, col) for col in original_names]

if len(new_names) != len(set(new_names)):
duplicate_names = set()
seen = set()
for name in new_names:
if name in seen:
duplicate_names.add(name)
else:
seen.add(name)
msg = (
"Invalid column_renames: resulting column names "
"contain duplicates: "
f"{sorted(duplicate_names)}"
)
raise JobError(msg)

arrow_table = new_dataset.arrow_table.rename_columns(new_names)
new_dataset = new_dataset.__class__(
arrow_table,
splits=new_dataset.splits,
types=new_dataset.types,
)
schema = {renames.get(col, col): schema[col] for col in schema}

# Cast dataset to inferred types
new_dataset = transform_dataset_with_schema(new_dataset, schema)

# Calculate metadata
new_dataset.compute_metadata()
gc.collect()

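The duplicate check in DatasetJob.run rejects renames that would make two columns share a name. A standalone sketch of that logic, with made-up column names, shows what gets rejected:

# Illustrative only: mirrors the duplicate detection added to DatasetJob.run.
original_names = ["age", "years", "income"]
renames = {"age": "years"}  # would collide with the existing "years" column

new_names = [renames.get(col, col) for col in original_names]
if len(new_names) != len(set(new_names)):
    duplicate_names, seen = set(), set()
    for name in new_names:
        if name in seen:
            duplicate_names.add(name)
        else:
            seen.add(name)
    print(f"rejected: duplicate column names {sorted(duplicate_names)}")
else:
    print("renames accepted")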
17 changes: 17 additions & 0 deletions DashAI/front/src/api/datasets.ts
@@ -79,6 +79,23 @@ export const updateDataset = async (
return response.data;
};

export const renameDatasetColumn = async (
id: number,
oldName: string,
newName: string,
): Promise<{
message: string;
old_name: string;
new_name: string;
columns: object;
}> => {
const response = await api.patch(`${datasetEndpoint}/${id}/columns/rename`, {
old_name: oldName,
new_name: newName,
});
return response.data;
};

export const deleteDataset = async (id: string): Promise<object> => {
const response = await api.delete(`${datasetEndpoint}/${id}`);
return response.data;
8 changes: 1 addition & 7 deletions DashAI/front/src/components/DatasetVisualization.jsx
@@ -22,12 +22,8 @@ import {
getDatasetInfo,
getDatasetFileFiltered,
} from "../api/datasets";
import DatasetTable from "./notebooks/dataset/DatasetTable";
import { getComponents } from "../api/component";
import { useTourContext } from "./tour/TourProvider";
import { useSnackbar } from "notistack";
import JobQueueWidget from "./jobs/JobQueueWidget";
import { getDatasetStatus } from "../utils/datasetStatus";
import { formatDate } from "../pages/results/constants/formatDate";
import Header from "./notebooks/dataset/header/Header";
import Tooltip from "@mui/material/Tooltip";
@@ -92,12 +88,10 @@ export default function DatasetVisualization({
fetchDatasetInfo();
}, [dataset.id, dataset.status]);

// fetchPage compatible with server-side filtering
const fetchDatasetPage = useCallback(
async (page, pageSize, filterModel) => {
if (isProcessing) return { rows: [], total: 0 };
try {
// Use getDatasetFile if no filters, else use getDatasetFileFiltered
const hasFilters =
filterModel &&
Array.isArray(filterModel.items) &&
@@ -122,7 +116,7 @@
);

const status = dataset.status;
const isProcessing = !(status === 3 || status === 4); // Finished or Error
const isProcessing = !(status === 3 || status === 4);

return (
<>
41 changes: 41 additions & 0 deletions DashAI/front/src/components/datasets/DatasetModal.jsx
@@ -115,6 +115,46 @@ function DatasetModal({ open, setOpen, updateDatasets }) {
}
};

const handleColumnRename = (oldName, newName) => {
setColumnsSpec((prevSpec) => {
const updatedSpec = { ...prevSpec };

// Move the spec from old name to new name
if (updatedSpec[oldName]) {
updatedSpec[newName] = updatedSpec[oldName];
delete updatedSpec[oldName];
}

return updatedSpec;
});

// Also update previewData to reflect the new column name
setPreviewData((prevData) => {
if (!prevData || !prevData.schema) return prevData;

const updatedSchema = { ...prevData.schema };
if (updatedSchema[oldName]) {
updatedSchema[newName] = updatedSchema[oldName];
delete updatedSchema[oldName];
}

const updatedSample = prevData.sample.map((row) => {
const newRow = { ...row };
if (oldName in newRow) {
newRow[newName] = newRow[oldName];
delete newRow[oldName];
}
return newRow;
});

return {
...prevData,
schema: updatedSchema,
sample: updatedSample,
};
});
};

const handleInferDataTypes = async (methods) => {
setLoading(true);
const formData = new FormData();
@@ -283,6 +323,7 @@ function DatasetModal({ open, setOpen, updateDatasets }) {
setNextEnabled={setNextEnabled}
columnsSpec={columnsSpec}
setColumnsSpec={setColumnsSpec}
onColumnRename={handleColumnRename}
/>
)}
</DialogContent>