Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions ftm_random/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def generate_random_entity(schema_name, entity_pool=None):

@click.command()
@click.option("--count", default=1, help="Number of entities to generate.")
@click.option(
"--count-per-schema",
"count_per_schema",
default=None,
type=int,
help="Number of entities to generate per schema (overrides --count).",
)
@click.option(
"--schema",
"schemata",
Expand Down Expand Up @@ -132,7 +139,7 @@ def generate_random_entity(schema_name, entity_pool=None):
help="List all available FTM schemas with their type and description.",
)
def generate_entities(
count, schemata, random_schema, connected, outfile, list_schemata
count, count_per_schema, schemata, random_schema, connected, outfile, list_schemata
):
"""Generate random followthemoney entities."""
if list_schemata:
Expand All @@ -147,6 +154,11 @@ def generate_entities(
click.echo(f"{name:<{col_name}} {entity_type:<{col_type}} {description}")
return

if count_per_schema is not None and random_schema:
raise click.ClickException(
"--count-per-schema cannot be used with --random-schema."
)

if random_schema:
choices = [
name for name, schema in model.schemata.items() if not schema.abstract
Expand All @@ -155,9 +167,15 @@ def generate_entities(
choices = list(schemata)

if not connected:
for _ in range(count):
entity = generate_random_entity(random.choice(choices))
click.echo(message=json.dumps(entity.to_dict()), file=outfile)
if count_per_schema is not None:
for schema_name in choices:
for _ in range(count_per_schema):
entity = generate_random_entity(schema_name)
click.echo(message=json.dumps(entity.to_dict()), file=outfile)
else:
for _ in range(count):
entity = generate_random_entity(random.choice(choices))
click.echo(message=json.dumps(entity.to_dict()), file=outfile)
return

# Connected mode: separate node and edge schemas, generate nodes first,
Expand All @@ -183,13 +201,17 @@ def generate_entities(
"--connected requires at least one non-edge schema (e.g. Person, Company)."
)

# Distribute count across all schemas (nodes first, then edges).
# Determine per-schema counts.
all_schemas = node_schemas + edge_schemas
num_schemas = len(all_schemas)
base, remainder = divmod(count, num_schemas)
schema_counts = {name: base for name in all_schemas}
for i in range(remainder):
schema_counts[all_schemas[i]] += 1
if count_per_schema is not None:
schema_counts = {name: count_per_schema for name in all_schemas}
else:
# Distribute count across all schemas (nodes first, then edges).
num_schemas = len(all_schemas)
base, remainder = divmod(count, num_schemas)
schema_counts = {name: base for name in all_schemas}
for i in range(remainder):
schema_counts[all_schemas[i]] += 1

# Generate node entities and collect their IDs by schema
entity_pool = defaultdict(list)
Expand Down
67 changes: 67 additions & 0 deletions tests/test_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,73 @@ def test_list_produces_no_entities(self):
assert not line.startswith("{")


# ---------------------------------------------------------------------------
# --count-per-schema
# ---------------------------------------------------------------------------


class TestCountPerSchema:
def test_single_schema(self):
result = runner.invoke(
generate_entities,
["--schema", "Person", "--count-per-schema", "3"],
)
assert result.exit_code == 0
entities = parse_output(result)
assert len(entities) == 3
assert all(e["schema"] == "Person" for e in entities)

def test_multiple_schemas(self):
result = runner.invoke(
generate_entities,
[
"--schema",
"Person",
"--schema",
"Company",
"--count-per-schema",
"4",
],
)
assert result.exit_code == 0
entities = parse_output(result)
assert len(entities) == 8
schemas = [e["schema"] for e in entities]
assert schemas.count("Person") == 4
assert schemas.count("Company") == 4

def test_connected_with_count_per_schema(self):
result = runner.invoke(
generate_entities,
[
"--schema",
"Person",
"--schema",
"Company",
"--schema",
"Directorship",
"--connected",
"--count-per-schema",
"2",
],
)
assert result.exit_code == 0
entities = parse_output(result)
assert len(entities) == 6
schemas = [e["schema"] for e in entities]
assert schemas.count("Person") == 2
assert schemas.count("Company") == 2
assert schemas.count("Directorship") == 2

def test_error_with_random_schema(self):
result = runner.invoke(
generate_entities,
["--random-schema", "--count-per-schema", "5"],
)
assert result.exit_code != 0
assert "--count-per-schema" in result.output


# ---------------------------------------------------------------------------
# --random-schema
# ---------------------------------------------------------------------------
Expand Down