From 236aadd66623e46639d6a52c6388230733f1aa62 Mon Sep 17 00:00:00 2001 From: Christian Stefanescu Date: Thu, 5 Mar 2026 15:49:26 +0100 Subject: [PATCH] Add a --count-per-schema flag This overrides the global --count and generates as many entities per given schema. It works with --connected but not with --random-schema. --- ftm_random/main.py | 42 ++++++++++++++++++++------ tests/test_connected.py | 67 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 10 deletions(-) diff --git a/ftm_random/main.py b/ftm_random/main.py index f09e23a..a518334 100644 --- a/ftm_random/main.py +++ b/ftm_random/main.py @@ -99,6 +99,13 @@ def generate_random_entity(schema_name, entity_pool=None): @click.command() @click.option("--count", default=1, help="Number of entities to generate.") +@click.option( + "--count-per-schema", + "count_per_schema", + default=None, + type=int, + help="Number of entities to generate per schema (overrides --count).", +) @click.option( "--schema", "schemata", @@ -132,7 +139,7 @@ def generate_random_entity(schema_name, entity_pool=None): help="List all available FTM schemas with their type and description.", ) def generate_entities( - count, schemata, random_schema, connected, outfile, list_schemata + count, count_per_schema, schemata, random_schema, connected, outfile, list_schemata ): """Generate random followthemoney entities.""" if list_schemata: @@ -147,6 +154,11 @@ def generate_entities( click.echo(f"{name:<{col_name}} {entity_type:<{col_type}} {description}") return + if count_per_schema is not None and random_schema: + raise click.ClickException( + "--count-per-schema cannot be used with --random-schema." + ) + if random_schema: choices = [ name for name, schema in model.schemata.items() if not schema.abstract @@ -155,9 +167,15 @@ def generate_entities( choices = list(schemata) if not connected: - for _ in range(count): - entity = generate_random_entity(random.choice(choices)) - click.echo(message=json.dumps(entity.to_dict()), file=outfile) + if count_per_schema is not None: + for schema_name in choices: + for _ in range(count_per_schema): + entity = generate_random_entity(schema_name) + click.echo(message=json.dumps(entity.to_dict()), file=outfile) + else: + for _ in range(count): + entity = generate_random_entity(random.choice(choices)) + click.echo(message=json.dumps(entity.to_dict()), file=outfile) return # Connected mode: separate node and edge schemas, generate nodes first, @@ -183,13 +201,17 @@ def generate_entities( "--connected requires at least one non-edge schema (e.g. Person, Company)." ) - # Distribute count across all schemas (nodes first, then edges). + # Determine per-schema counts. all_schemas = node_schemas + edge_schemas - num_schemas = len(all_schemas) - base, remainder = divmod(count, num_schemas) - schema_counts = {name: base for name in all_schemas} - for i in range(remainder): - schema_counts[all_schemas[i]] += 1 + if count_per_schema is not None: + schema_counts = {name: count_per_schema for name in all_schemas} + else: + # Distribute count across all schemas (nodes first, then edges). + num_schemas = len(all_schemas) + base, remainder = divmod(count, num_schemas) + schema_counts = {name: base for name in all_schemas} + for i in range(remainder): + schema_counts[all_schemas[i]] += 1 # Generate node entities and collect their IDs by schema entity_pool = defaultdict(list) diff --git a/tests/test_connected.py b/tests/test_connected.py index 06cc9ac..8a03296 100644 --- a/tests/test_connected.py +++ b/tests/test_connected.py @@ -301,6 +301,73 @@ def test_list_produces_no_entities(self): assert not line.startswith("{") +# --------------------------------------------------------------------------- +# --count-per-schema +# --------------------------------------------------------------------------- + + +class TestCountPerSchema: + def test_single_schema(self): + result = runner.invoke( + generate_entities, + ["--schema", "Person", "--count-per-schema", "3"], + ) + assert result.exit_code == 0 + entities = parse_output(result) + assert len(entities) == 3 + assert all(e["schema"] == "Person" for e in entities) + + def test_multiple_schemas(self): + result = runner.invoke( + generate_entities, + [ + "--schema", + "Person", + "--schema", + "Company", + "--count-per-schema", + "4", + ], + ) + assert result.exit_code == 0 + entities = parse_output(result) + assert len(entities) == 8 + schemas = [e["schema"] for e in entities] + assert schemas.count("Person") == 4 + assert schemas.count("Company") == 4 + + def test_connected_with_count_per_schema(self): + result = runner.invoke( + generate_entities, + [ + "--schema", + "Person", + "--schema", + "Company", + "--schema", + "Directorship", + "--connected", + "--count-per-schema", + "2", + ], + ) + assert result.exit_code == 0 + entities = parse_output(result) + assert len(entities) == 6 + schemas = [e["schema"] for e in entities] + assert schemas.count("Person") == 2 + assert schemas.count("Company") == 2 + assert schemas.count("Directorship") == 2 + + def test_error_with_random_schema(self): + result = runner.invoke( + generate_entities, + ["--random-schema", "--count-per-schema", "5"], + ) + assert result.exit_code != 0 + assert "--count-per-schema" in result.output + + # --------------------------------------------------------------------------- # --random-schema # ---------------------------------------------------------------------------