14 changes: 11 additions & 3 deletions sqlite/Readme.md
@@ -39,9 +39,17 @@ bash ./sqlite/graphsample_delete.sh sqlite/xxx.db 2>&1 | tee sqlite/logs/delete_
 bash ./sqlite/graphsample_delete.sh 2>&1 | tee sqlite/logs/delete_$(date +"%Y%m%d_%H%M%S").log
 ```

-## Merge Databases and Upload to Hugging Face
+## Merge Databases

 ```bash
-# Usage: python ./sqlite/upload.py --main_db_path <path> --new_db_path <path>
-python ./sqlite/upload.py --main_db_path <path> --new_db_path <path>
+# Usage: python ./sqlite/merge_db.py --main_db_path <path> --new_db_path <path>
+python ./sqlite/merge_db.py --main_db_path sqlite/GraphNet.db --new_db_path sqlite/new.db
 ```
+
+## Upload to Hugging Face
+
+```bash
+python ./sqlite/upload.py
+```
+
+**Note:** Set the `HF_TOKEN` variable in `upload.py` before running.
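`upload.py` itself is not part of this diff, so purely as a hedged sketch: if it uses the `huggingface_hub` client, its upload step might look roughly like the following. The repo id, destination path, and token value are illustrative placeholders, not taken from the repository.

```python
# Hypothetical sketch only -- upload.py's actual logic is not shown in this diff.
from huggingface_hub import HfApi

HF_TOKEN = "hf_..."  # per the Readme note, fill this in before running

api = HfApi(token=HF_TOKEN)
api.upload_file(
    path_or_fileobj="sqlite/GraphNet.db",  # merged database from merge_db.py
    path_in_repo="GraphNet.db",            # placeholder destination name
    repo_id="your-org/GraphNet",           # placeholder repo id
    repo_type="dataset",
)
```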
9 changes: 9 additions & 0 deletions sqlite/graphsample_delete.py
@@ -6,6 +6,8 @@
     SubgraphSource,
     DimensionGeneralizationSource,
     DataTypeGeneralizationSource,
+    SampleOpNameList,
+    SampleOpName,
 )


@@ -63,6 +65,13 @@ def delete_graph_sample(db_path: str, relative_model_path: str, repo_uid: str =
             datatype_source.deleted = True
             datatype_source.delete_at = delete_at

+        session.query(SampleOpNameList).filter(
+            SampleOpNameList.sample_uuid == graph_sample.uuid
+        ).update({"deleted": True, "delete_at": delete_at})
+        session.query(SampleOpName).filter(
+            SampleOpName.sample_uuid == graph_sample.uuid
+        ).update({"deleted": True, "delete_at": delete_at})
+
         session.commit()
         print(f"Successfully deleted: {relative_model_path}")
         return True
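The new `update()` calls extend the file's soft-delete convention (set the `deleted` flag and stamp `delete_at`) to the two op-name tables rather than removing rows. A minimal sketch, assuming the same session and ORM models as above, of how a reader would then select only live rows:

```python
# Sketch: read back only non-soft-deleted op-name rows for a sample.
live_op_names = (
    session.query(SampleOpName)
    .filter(
        SampleOpName.sample_uuid == graph_sample.uuid,
        SampleOpName.deleted == False,  # noqa: E712 -- SQLAlchemy column comparison
    )
    .order_by(SampleOpName.op_idx)
    .all()
)
```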
151 changes: 130 additions & 21 deletions sqlite/graphsample_insert.py
@@ -11,7 +11,10 @@
     SubgraphSource,
     DimensionGeneralizationSource,
     DataTypeGeneralizationSource,
+    SampleOpName,
+    SampleOpNameList,
 )
+from sqlalchemy import delete as sql_delete
 from sqlalchemy.exc import IntegrityError


@@ -82,7 +85,7 @@ def insert_subgraph_source(
     if not full_graph:
         raise ValueError(f"Full graph not found for path: {parent_relative_path}")

-    range_info = _get_range_info(model_path_prefix, relative_model_path)
+    range_info = _get_parent_key_and_range(model_path_prefix, relative_model_path)
     subgraph_source = SubgraphSource(
         subgraph_uuid=subgraph_uuid,
         full_graph_uuid=full_graph.uuid,
@@ -108,26 +111,6 @@
         session.close()


-def _get_range_info(model_path_prefix: str, relative_model_path: str):
-    model_path = Path(model_path_prefix) / relative_model_path
-    subgraph_sources_file = model_path / "subgraph_sources.json"
-    if not subgraph_sources_file.exists():
-        return {"start": -1, "end": -1}
-
-    try:
-        with open(subgraph_sources_file) as f:
-            data = json.load(f)
-        for key, ranges in data.items():
-            if isinstance(ranges, list):
-                r = ranges[0]
-                if isinstance(r, list) and len(r) == 2:
-                    return {"start": r[0], "end": r[1]}
-        return {"start": -1, "end": -1}
-    except (json.JSONDecodeError, KeyError, TypeError, IndexError) as e:
-        print(f"Warning: Failed to parse {subgraph_sources_file}: {e}")
-        return {"start": -1, "end": -1}
-
-
 def get_parent_relative_path(relative_path: str) -> str:
     if "_decomposed" not in relative_path:
         return None
@@ -274,6 +257,119 @@ def _get_data_type(model_path_prefix: str, relative_model_path: str):
     return "todo"


+# SampleOpNameList and SampleOpName insert func
+def _get_parent_key_and_range(model_path_prefix: str, relative_model_path: str) -> dict:
+    model_path = Path(model_path_prefix) / relative_model_path
+    subgraph_sources_file = model_path / "subgraph_sources.json"
+    if not subgraph_sources_file.exists():
+        return {"parent_key": "", "start": -1, "end": -1}
+
+    try:
+        with open(subgraph_sources_file) as f:
+            data = json.load(f)
+        for key, ranges in data.items():
+            if isinstance(ranges, list) and len(ranges) > 0:
+                r = ranges[0]
+                if isinstance(r, list) and len(r) == 2:
+                    return {"parent_key": key, "start": r[0], "end": r[1]}
+        return {"parent_key": "", "start": -1, "end": -1}
+    except (json.JSONDecodeError, KeyError, TypeError, IndexError) as e:
+        print(f"Warning: Failed to parse {subgraph_sources_file}: {e}")
+        return {"parent_key": "", "start": -1, "end": -1}
+
+
+def insert_sample_op_name_list(
+    sample_uuid: str,
+    model_path_prefix: str,
+    op_names_path_prefix: str,
+    relative_model_path: str,
+    db_path: str,
+):
+    if not op_names_path_prefix:
+        print("op_names_path_prefix not provided, skipping insert_sample_op_name_list")
+        return
+
+    range_info = _get_parent_key_and_range(model_path_prefix, relative_model_path)
+    parent_key = range_info["parent_key"]
+    start = range_info["start"]
+    end = range_info["end"]
+
+    if start == -1 or end == -1 or not parent_key:
+        print(
+            f"Invalid range info for {relative_model_path}, skipping insert_sample_op_name_list"
+        )
+        return
+
+    op_size = end - start
+    op_names_file = Path(op_names_path_prefix) / parent_key / "op_names.txt"
+    if not op_names_file.exists():
+        print(
+            f"op_names.txt not found at {op_names_file}, skipping insert_sample_op_name_list"
+        )
+        return
+
+    try:
+        with open(op_names_file) as f:
+            all_op_names = [line.strip() for line in f.readlines() if line.strip()]
+    except Exception as e:
+        print(f"Warning: Failed to read {op_names_file}: {e}")
+        return
+
+    op_start = start
+    op_end = end
+    if op_end > len(all_op_names):
+        print(f"Warning: op_end {op_end} exceeds total ops {len(all_op_names)}")
+        op_end = len(all_op_names)
+    if op_start >= op_end:
+        print(f"Warning: op_start {op_start} >= op_end {op_end}")
+        return
+
+    selected_op_names = all_op_names[op_start:op_end]
+    op_names_json = json.dumps(
+        [{"op_name": name, "op_idx": i} for i, name in enumerate(selected_op_names)]
+    )
+    session = get_session(db_path)
+    try:
+        session.execute(
+            sql_delete(SampleOpNameList).where(
+                SampleOpNameList.sample_uuid == sample_uuid
+            )
+        )
+        session.execute(
+            sql_delete(SampleOpName).where(SampleOpName.sample_uuid == sample_uuid)
+        )
+        sample_op_name_list = SampleOpNameList(
+            sample_uuid=sample_uuid,
+            op_names_json=op_names_json,
+            create_at=datetime.now(),
+            deleted=False,
+            delete_at=None,
+        )
+        session.add(sample_op_name_list)
+
+        for idx, op_name in enumerate(selected_op_names):
+            sample_op_name = SampleOpName(
+                sample_uuid=sample_uuid,
+                op_name=op_name,
+                op_idx=idx,
+                op_size=op_size,
+                create_at=datetime.now(),
+                deleted=False,
+                delete_at=None,
+            )
+            session.add(sample_op_name)
+
+        session.commit()
+        print(
+            f"Inserted {len(selected_op_names)} op_names for sample_uuid={sample_uuid}"
+        )
+    except IntegrityError as e:
+        session.rollback()
+        raise e
+    finally:
+        session.close()
+
+
 # main func
 def main(args):
     data = get_graph_sample_data(
@@ -294,6 +390,13 @@ def main(args):
         relative_model_path=args.relative_model_path,
         db_path=args.db_path,
     )
+    insert_sample_op_name_list(
+        sample_uuid=data["uuid"],
+        model_path_prefix=args.model_path_prefix,
+        op_names_path_prefix=args.op_names_path_prefix,
+        relative_model_path=args.relative_model_path,
+        db_path=args.db_path,
+    )
     if args.sample_type in ["fusible_graph"]:
         insert_dimension_generalization_source(
             subgraph_source_data["subgraph_uuid"],
@@ -358,5 +461,11 @@ def main(args):
         default="graphnet.db",
         help="Database file path e.g 'graphnet.db'",
     )
+    parser.add_argument(
+        "--op_names_path_prefix",
+        type=str,
+        required=False,
+        help="Path prefix of op names file",
+    )
     args = parser.parse_args()
     main(args)
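For reference, the two new helpers fit together like this: `_get_parent_key_and_range` takes the first `[start, end]` pair from the sample's `subgraph_sources.json`, and `insert_sample_op_name_list` slices that range out of the parent sample's `op_names.txt`. A short sketch of the expected shapes, where the key `"resnet18"` and the numbers are hypothetical, not taken from the repository:

```python
import json

# Hypothetical subgraph_sources.json contents for one decomposed sample:
data = json.loads('{"resnet18": [[12, 47]]}')
key, ranges = next(iter(data.items()))
start, end = ranges[0]
# _get_parent_key_and_range would return:
#   {"parent_key": "resnet18", "start": 12, "end": 47}
# insert_sample_op_name_list then reads
#   <op_names_path_prefix>/resnet18/op_names.txt
# and stores lines 12..46 (op_size = 35) as SampleOpName rows, clearing any
# earlier rows for the same sample_uuid first so re-runs stay idempotent.
```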
28 changes: 15 additions & 13 deletions sqlite/graphsample_insert.sh
@@ -3,14 +3,13 @@ set -x

 GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
 DB_PATH="${1:-${GRAPH_NET_ROOT}/sqlite/GraphNet.db}"
-TORCH_MODEL_LIST="graph_net/config/small10_torch_samples_list.txt"
+TORCH_MODEL_LIST="graph_net/config/torch_samples_list.txt"
 PADDLE_MODEL_LIST="graph_net/config/small10_paddle_samples_list.txt"
-TYPICAL_GRAPH_SAMPLES_LIST="20260202_small10/range_decomposed_subgraph_sample_list.txt"
-FUSIBLE_GRAPH_SAMPLES_LIST="20260202_small10/workspace_dimension_subgraph_samples/all_dimension_subgraph_list.txt"
-SOLE_OP_GRAPH_SAMPLES_LIST="20260202_small10/sole/solo_sample_list.txt"
+TYPICAL_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/deduplicated_subgraph_sample_list.txt"
+FUSIBLE_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/deduplicated_dimension_generalized_subgraph_sample_list.txt"
+SOLE_OP_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/sole/solo_sample_list.txt"
 ORDER_VALUE=0

-
 if [ ! -f "$DB_PATH" ]; then
     echo "Fail ! No Database ! : $DB_PATH"
     exit 1
@@ -21,7 +20,7 @@ while IFS= read -r model_rel_path; do
     python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
         --model_path_prefix "$GRAPH_NET_ROOT" \
         --relative_model_path "$model_rel_path" \
-        --repo_uid "github_torch_samples" \
+        --repo_uid "hf_torch_samples" \
         --sample_type "full_graph" \
         --order_value "$ORDER_VALUE" \
         --db_path "$DB_PATH"
@@ -35,7 +34,7 @@ while IFS= read -r model_rel_path; do
    python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
         --model_path_prefix "$GRAPH_NET_ROOT" \
         --relative_model_path "$model_rel_path" \
-        --repo_uid "github_paddle_samples" \
+        --repo_uid "hf_paddle_samples" \
         --sample_type "full_graph" \
         --order_value "$ORDER_VALUE" \
         --db_path "$DB_PATH"
@@ -47,9 +46,10 @@ done < "$PADDLE_MODEL_LIST"
 while IFS= read -r model_rel_path; do
     echo "insert : $model_rel_path"
     python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
-        --model_path_prefix "${GRAPH_NET_ROOT}/20260202_small10/range_decompose" \
+        --model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/typical_graph" \
         --relative_model_path "$model_rel_path" \
-        --repo_uid "github_torch_samples" \
+        --op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \
+        --repo_uid "hf_torch_samples" \
         --sample_type "typical_graph" \
         --order_value "$ORDER_VALUE" \
         --db_path "$DB_PATH"
@@ -61,9 +61,10 @@ done < "$TYPICAL_GRAPH_SAMPLES_LIST"
 while IFS= read -r model_rel_path; do
     echo "insert : $model_rel_path"
     python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
-        --model_path_prefix "${GRAPH_NET_ROOT}/20260202_small10/workspace_dimension_subgraph_samples" \
+        --model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/fusible_graph" \
         --relative_model_path "$model_rel_path" \
-        --repo_uid "github_torch_samples" \
+        --op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \
+        --repo_uid "hf_torch_samples" \
         --sample_type "fusible_graph" \
         --order_value "$ORDER_VALUE" \
         --db_path "$DB_PATH"
@@ -75,9 +76,10 @@ done < "$FUSIBLE_GRAPH_SAMPLES_LIST"
 while IFS= read -r model_rel_path; do
     echo "insert : $model_rel_path"
     python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
-        --model_path_prefix "${GRAPH_NET_ROOT}/20260202_small10/sole" \
+        --model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/sole_op_graph" \
         --relative_model_path "$model_rel_path" \
-        --repo_uid "github_torch_samples" \
+        --op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \
+        --repo_uid "hf_torch_samples" \
         --sample_type "sole_op_graph" \
         --order_value "$ORDER_VALUE" \
         --db_path "$DB_PATH"
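Each loop forwards the same flag set per sample, so a one-off manual insert can be assembled from the script above. A usage sketch for a single typical-graph sample, with `<sample_dir>` standing in for one line of the list file:

```bash
python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
    --model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/typical_graph" \
    --relative_model_path "<sample_dir>" \
    --op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \
    --repo_uid "hf_torch_samples" \
    --sample_type "typical_graph" \
    --order_value 0 \
    --db_path "${GRAPH_NET_ROOT}/sqlite/GraphNet.db"
```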