⚡️ Speed up method ProgressTracker.update by 14%
#19
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 14% (0.14x) speedup for
ProgressTracker.updateinspacy/training/pretrain.py⏱️ Runtime :
206 microseconds→181 microseconds(best of59runs)📝 Explanation and details
The optimization achieves a 13% speedup through two key changes that eliminate overhead in performance-critical operations:
1. Word Count Calculation Optimization
Replaced
sum(len(doc) for doc in docs)with an explicit loop accumulation. While both approaches are O(n), the explicit loop avoids the generator expression overhead and function call overhead ofsum(). This is particularly beneficial when processing many small documents, as shown in test cases where speedups range from 12-30%.2. String Formatting Optimization
In
_smart_round(), replaced the%formatting operator with f-string formatting (f"{figure:.{n_decimal}f}"), which is faster in modern Python. Also optimized the integer conversion by computingint(figure)once and reusing it, avoiding redundant type conversions.3. Minor Time Call Optimization
Cached
time.time()call in a variablenowto avoid calling it twice within the same conditional block, improving both accuracy and performance.Performance Impact Analysis:
The optimizations are most effective for workloads with:
_smart_roundcalls when status updates are triggeredThe 13% overall speedup is significant for training workloads where
ProgressTracker.update()may be called thousands of times during model training, making these micro-optimizations compound into meaningful performance gains. All test cases show consistent improvements, with the largest gains (25-35%) occurring in scenarios with many small document batches.✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import time
from collections import Counter
from typing import Union
imports
import pytest
from spacy.training.pretrain import ProgressTracker
unit tests
---- Basic Test Cases ----
def test_update_returns_none_when_frequency_not_met():
"""Should return None if words_since_update < frequency"""
pt = ProgressTracker(frequency=10)
docs = [["word1", "word2"], ["word3"]] # 3 words
codeflash_output = pt.update(epoch=1, loss=1.5, docs=docs); result = codeflash_output # 4.21μs -> 3.36μs (25.4% faster)
def test_update_returns_status_when_frequency_met():
"""Should return status tuple when words_since_update >= frequency"""
pt = ProgressTracker(frequency=3)
docs = [["word1", "word2"], ["word3"]] # 3 words
# First update: hits frequency exactly
codeflash_output = pt.update(epoch=2, loss=2.0, docs=docs); result = codeflash_output
# Check tuple contents
epoch_out, nr_word, loss_str, loss_per_word_str, wps = result
def test_update_accumulates_loss_and_words():
"""Loss and word count should accumulate over multiple calls"""
pt = ProgressTracker(frequency=10)
docs1 = [["a", "b", "c"], ["d"]]
docs2 = [["e", "f"]]
pt.update(epoch=1, loss=1.0, docs=docs1) # 4.10μs -> 3.46μs (18.7% faster)
pt.update(epoch=1, loss=2.0, docs=docs2) # 1.52μs -> 1.34μs (13.5% faster)
def test_update_multiple_epochs():
"""Should accumulate words/loss per epoch correctly"""
pt = ProgressTracker(frequency=10)
docs1 = [["a", "b"], ["c"]]
docs2 = [["d", "e"]]
pt.update(epoch=1, loss=1.0, docs=docs1) # 3.18μs -> 2.57μs (23.7% faster)
pt.update(epoch=2, loss=2.0, docs=docs2) # 1.66μs -> 1.34μs (23.6% faster)
---- Edge Test Cases ----
def test_update_zero_loss_and_empty_docs():
"""Should handle zero loss and empty docs gracefully"""
pt = ProgressTracker(frequency=1)
docs = []
codeflash_output = pt.update(epoch=1, loss=0.0, docs=docs); result = codeflash_output # 3.01μs -> 2.34μs (29.0% faster)
def test_update_docs_with_empty_lists():
"""Should count only actual words in docs, not empty lists"""
pt = ProgressTracker(frequency=2)
docs = [[], ["a"], []]
codeflash_output = pt.update(epoch=1, loss=1.0, docs=docs); result = codeflash_output # 3.16μs -> 2.57μs (23.2% faster)
def test_update_negative_loss():
"""Should handle negative loss values (e.g., for regularization)"""
pt = ProgressTracker(frequency=1)
docs = [["a", "b"]]
codeflash_output = pt.update(epoch=1, loss=-1.5, docs=docs); result = codeflash_output
def test_update_large_loss_and_word_count():
"""Should handle very large loss and word counts"""
pt = ProgressTracker(frequency=100)
docs = [["a"] * 100] # 100 words
codeflash_output = pt.update(epoch=1, loss=1e9, docs=docs); result = codeflash_output
def test_update_frequency_exact_boundary():
"""Should trigger update exactly at frequency boundary"""
pt = ProgressTracker(frequency=5)
docs1 = [["a", "b"]]
docs2 = [["c", "d", "e"]]
pt.update(epoch=1, loss=1.0, docs=docs1) # 2 words
codeflash_output = pt.update(epoch=1, loss=2.0, docs=docs2); result = codeflash_output # 3 words, total 5
def test_update_multiple_updates_and_loss_per_word():
"""Should correctly calculate loss_per_word over multiple updates"""
pt = ProgressTracker(frequency=3)
docs1 = [["a", "b"]]
docs2 = [["c"]]
pt.update(epoch=1, loss=1.0, docs=docs1) # 2 words
codeflash_output = pt.update(epoch=1, loss=2.0, docs=docs2); result = codeflash_output # 1 word, triggers update
# Next update, prev_loss should be 3.0
docs3 = [["d", "e", "f"]]
codeflash_output = pt.update(epoch=1, loss=3.0, docs=docs3); result2 = codeflash_output # 3 words, triggers update
def test_update_with_non_integer_loss():
"""Should handle float loss values and round correctly"""
pt = ProgressTracker(frequency=2)
docs = [["a"], ["b"]]
codeflash_output = pt.update(epoch=1, loss=1.23456, docs=docs); result = codeflash_output
def test_update_with_zero_frequency():
"""Should always trigger update if frequency=0"""
pt = ProgressTracker(frequency=0)
docs = [["a"], ["b"]]
codeflash_output = pt.update(epoch=1, loss=1.0, docs=docs); result = codeflash_output
# Should update every call
codeflash_output = pt.update(epoch=1, loss=2.0, docs=docs); result2 = codeflash_output
---- Large Scale Test Cases ----
def test_update_large_number_of_docs_and_words():
"""Should handle large batches efficiently and correctly"""
pt = ProgressTracker(frequency=500)
docs = [["word"] * 10 for _ in range(50)] # 500 words
codeflash_output = pt.update(epoch=1, loss=100.0, docs=docs); result = codeflash_output
def test_update_multiple_large_batches_accumulate():
"""Should accumulate counts and trigger multiple updates as needed"""
pt = ProgressTracker(frequency=400)
docs1 = [["a"] * 100 for _ in range(2)] # 200 words
docs2 = [["b"] * 100 for _ in range(3)] # 300 words
pt.update(epoch=1, loss=50.0, docs=docs1) # 200 words, no update
codeflash_output = pt.update(epoch=1, loss=60.0, docs=docs2); result = codeflash_output # 300 words, total 500
def test_update_scalability_with_many_epochs():
"""Should handle many epochs and track words per epoch correctly"""
pt = ProgressTracker(frequency=1000)
for epoch in range(10):
docs = [["w"] * 100 for _ in range(5)] # 500 words per epoch
pt.update(epoch=epoch, loss=10.0, docs=docs)
# Should have 500 words per epoch
for epoch in range(10):
pass
def test_update_performance_under_large_scale():
"""Should not degrade performance with large batches (under 1000 words)"""
pt = ProgressTracker(frequency=1000)
docs = [["x"] * 50 for _ in range(20)] # 1000 words
start_time = time.time()
codeflash_output = pt.update(epoch=1, loss=500.0, docs=docs); result = codeflash_output
duration = time.time() - start_time
def test_update_multiple_calls_large_scale():
"""Should handle multiple large calls and accumulate correctly"""
pt = ProgressTracker(frequency=900)
docs1 = [["a"] * 300 for _ in range(2)] # 600 words
docs2 = [["b"] * 100 for _ in range(3)] # 300 words
pt.update(epoch=1, loss=200.0, docs=docs1) # 600 words, no update
codeflash_output = pt.update(epoch=1, loss=100.0, docs=docs2); result = codeflash_output # 300 words, total 900
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import time
from collections import Counter
from typing import Union
imports
import pytest
from spacy.training.pretrain import ProgressTracker
unit tests
----------- BASIC TEST CASES ------------
def test_update_returns_none_when_frequency_not_met():
"""Should return None if nr_word < frequency after update."""
pt = ProgressTracker(frequency=10)
# docs: 2 docs, each with 2 words; total 4 words
codeflash_output = pt.update(epoch=1, loss=1.5, docs=[["a", "b"], ["c", "d"]]); result = codeflash_output # 4.22μs -> 3.31μs (27.3% faster)
# Second batch, still not enough
codeflash_output = pt.update(epoch=1, loss=2.0, docs=[["e", "f"]]); result = codeflash_output # 1.41μs -> 1.44μs (2.29% slower)
def test_update_returns_status_when_frequency_met(monkeypatch):
"""Should return a tuple when nr_word >= frequency after update."""
pt = ProgressTracker(frequency=5)
# Patch time so that wps calculation is deterministic
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
# First batch: 3 words
pt.update(epoch=2, loss=2.0, docs=[["a", "b", "c"]]) # 3.29μs -> 2.60μs (26.9% faster)
# Second batch: 2 words, triggers update (total 5 words)
codeflash_output = pt.update(epoch=2, loss=1.0, docs=[["d"], ["e"]]); result = codeflash_output # 8.28μs -> 7.07μs (17.1% faster)
def test_words_per_epoch_accumulates_correctly():
"""words_per_epoch should accumulate word counts per epoch."""
pt = ProgressTracker(frequency=100)
pt.update(epoch=1, loss=1.0, docs=[["a", "b"], ["c"]]) # 3.13μs -> 2.45μs (27.8% faster)
pt.update(epoch=1, loss=2.0, docs=[["d", "e"]]) # 1.45μs -> 1.29μs (12.0% faster)
pt.update(epoch=2, loss=1.0, docs=[["f"]]) # 1.20μs -> 963ns (24.3% faster)
def test_loss_accumulates():
"""Loss and epoch_loss should accumulate correctly."""
pt = ProgressTracker(frequency=1000)
pt.update(epoch=1, loss=1.0, docs=[["a"]]) # 2.79μs -> 2.20μs (26.8% faster)
pt.update(epoch=1, loss=2.5, docs=[["b", "c"]]) # 1.47μs -> 1.18μs (24.3% faster)
def test_update_with_empty_docs():
"""Should handle empty docs list (zero words, loss still accumulates)."""
pt = ProgressTracker(frequency=1)
codeflash_output = pt.update(epoch=1, loss=1.0, docs=[]); result = codeflash_output # 2.66μs -> 2.29μs (16.4% faster)
----------- EDGE TEST CASES ------------
def test_update_with_zero_loss_and_zero_words(monkeypatch):
"""Should handle batch with zero loss and zero words."""
pt = ProgressTracker(frequency=1)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
codeflash_output = pt.update(epoch=1, loss=0.0, docs=[[]]); result = codeflash_output # 3.09μs -> 2.33μs (32.7% faster)
def test_update_with_negative_loss(monkeypatch):
"""Should correctly accumulate negative loss values."""
pt = ProgressTracker(frequency=2)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
pt.update(epoch=1, loss=-1.0, docs=[["a"]]) # 2.87μs -> 2.40μs (19.6% faster)
codeflash_output = pt.update(epoch=1, loss=-2.0, docs=[["b"]]); result = codeflash_output # 7.53μs -> 6.60μs (14.1% faster)
def test_update_with_large_loss(monkeypatch):
"""Should handle very large loss values without error."""
pt = ProgressTracker(frequency=2)
fake_time = pt.last_time + 2.0
monkeypatch.setattr(time, "time", lambda: fake_time)
pt.update(epoch=1, loss=1e10, docs=[["a"]]) # 2.81μs -> 2.21μs (27.5% faster)
codeflash_output = pt.update(epoch=1, loss=1e10, docs=[["b"]]); result = codeflash_output # 4.77μs -> 4.34μs (9.83% faster)
def test_update_with_docs_of_various_lengths(monkeypatch):
"""Should correctly sum words in docs of different lengths."""
pt = ProgressTracker(frequency=6)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
pt.update(epoch=1, loss=1.0, docs=[["a"], ["b", "c"], [], ["d", "e", "f"]]) # 8.38μs -> 6.96μs (20.5% faster)
# 1 + 2 + 0 + 3 = 6 words, triggers update
codeflash_output = pt.update(epoch=1, loss=2.0, docs=[]); result = codeflash_output # 1.70μs -> 1.40μs (21.8% faster)
def test_update_with_non_integer_loss(monkeypatch):
"""Should handle float loss values and round them appropriately."""
pt = ProgressTracker(frequency=3)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
pt.update(epoch=1, loss=1.2345, docs=[["a"]]) # 2.79μs -> 2.19μs (27.5% faster)
pt.update(epoch=1, loss=2.3456, docs=[["b", "c"]]) # 6.28μs -> 5.62μs (11.8% faster)
codeflash_output = pt.update(epoch=1, loss=0.0, docs=[]); result = codeflash_output # 1.34μs -> 1.16μs (15.7% faster)
def test_update_with_empty_doc(monkeypatch):
"""Should handle docs containing empty lists."""
pt = ProgressTracker(frequency=2)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
pt.update(epoch=1, loss=1.0, docs=[[]]) # 2.83μs -> 2.26μs (24.9% faster)
codeflash_output = pt.update(epoch=1, loss=1.0, docs=[["a", "b"]]); result = codeflash_output # 6.02μs -> 5.43μs (10.8% faster)
----------- LARGE SCALE TEST CASES ------------
def test_update_large_number_of_docs(monkeypatch):
"""Should handle large batches efficiently and accurately."""
pt = ProgressTracker(frequency=1000)
fake_time = pt.last_time + 5.0
monkeypatch.setattr(time, "time", lambda: fake_time)
# 500 docs, each with 2 words = 1000 words
docs = [["word1", "word2"] for _ in range(500)]
codeflash_output = pt.update(epoch=1, loss=500.0, docs=docs); result = codeflash_output # 22.7μs -> 21.4μs (6.21% faster)
def test_update_multiple_epochs_large_scale(monkeypatch):
"""Should accumulate words and loss correctly across multiple epochs."""
pt = ProgressTracker(frequency=500)
fake_time = pt.last_time + 2.0
monkeypatch.setattr(time, "time", lambda: fake_time)
docs1 = [["a"] * 5 for _ in range(50)] # 250 words
docs2 = [["b"] * 5 for _ in range(50)] # 250 words
pt.update(epoch=1, loss=100.0, docs=docs1) # 4.49μs -> 3.86μs (16.3% faster)
codeflash_output = pt.update(epoch=2, loss=200.0, docs=docs2); result = codeflash_output # 7.94μs -> 7.23μs (9.86% faster)
def test_update_with_max_docs_and_words(monkeypatch):
"""Should handle the largest reasonable batch size (999 docs, 1 word each)."""
pt = ProgressTracker(frequency=999)
fake_time = pt.last_time + 10.0
monkeypatch.setattr(time, "time", lambda: fake_time)
docs = [["w"] for _ in range(999)]
codeflash_output = pt.update(epoch=1, loss=999.0, docs=docs); result = codeflash_output # 36.6μs -> 34.7μs (5.43% faster)
def test_update_multiple_calls_large_scale(monkeypatch):
"""Should correctly accumulate over many calls and trigger update only when frequency is met."""
pt = ProgressTracker(frequency=1000)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
# 10 batches of 100 words each
for i in range(9):
codeflash_output = pt.update(epoch=1, loss=10.0, docs=[["w"] * 100]); result = codeflash_output # 10.7μs -> 8.94μs (19.8% faster)
# 10th batch triggers update
codeflash_output = pt.update(epoch=1, loss=10.0, docs=[["w"] * 100]); result = codeflash_output # 5.38μs -> 4.94μs (8.76% faster)
----------- FUNCTIONALITY & MUTATION TESTS ------------
def test_update_mutation_nr_word(monkeypatch):
"""Mutation: If nr_word is not incremented properly, status[1] will be wrong."""
pt = ProgressTracker(frequency=3)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
pt.update(epoch=1, loss=1.0, docs=[["a", "b"]]) # 2.86μs -> 2.20μs (30.2% faster)
codeflash_output = pt.update(epoch=1, loss=2.0, docs=[["c"]]); result = codeflash_output # 5.74μs -> 5.41μs (6.16% faster)
def test_update_mutation_loss(monkeypatch):
"""Mutation: If loss is not accumulated, status[2] will be wrong."""
pt = ProgressTracker(frequency=2)
fake_time = pt.last_time + 1.0
monkeypatch.setattr(time, "time", lambda: fake_time)
pt.update(epoch=1, loss=1.0, docs=[["a"]]) # 2.83μs -> 2.23μs (27.1% faster)
codeflash_output = pt.update(epoch=1, loss=2.0, docs=[["b"]]); result = codeflash_output # 5.59μs -> 5.16μs (8.41% faster)
def test_update_mutation_words_per_epoch():
"""Mutation: If words_per_epoch is not updated, counts will be wrong."""
pt = ProgressTracker(frequency=5)
pt.update(epoch=1, loss=1.0, docs=[["a", "b"]])
pt.update(epoch=2, loss=2.0, docs=[["c", "d", "e"]])
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes
git checkout codeflash/optimize-ProgressTracker.update-mhwsd6cgand push.