Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/markdown/blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,15 @@ pub(super) fn push_latex_block_lines(
item_stack: &mut [ItemState],
) {
let rendered = latex::to_unicode(content);
let content_lines: Vec<&str> = rendered.lines().collect();
let all_lines: Vec<&str> = rendered.lines().collect();
let start = all_lines
.iter()
.position(|l| !l.trim().is_empty())
.unwrap_or(0);
let end = all_lines
.iter()
.rposition(|l| !l.trim().is_empty())
.map_or(start, |e| e + 1);
let content_style = Style::default().fg(theme.latex_block_fg);
push_special_block_lines(
lines,
Expand All @@ -412,7 +420,7 @@ pub(super) fn push_latex_block_lines(
item_stack,
SpecialBlockCtx {
label: "latex",
content_lines: &content_lines,
content_lines: &all_lines[start..end],
show_line_numbers: true,
center: false,
make_spans: |line| vec![Span::styled(line.to_string(), content_style)],
Expand Down
167 changes: 153 additions & 14 deletions src/markdown/latex.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use unicode_width::UnicodeWidthStr;

pub(crate) fn to_unicode(text: &str) -> String {
let preprocessed = strip_command_spaces(text);
let converted = unicodeit::replace(&preprocessed);
Expand All @@ -12,15 +14,51 @@ fn strip_command_spaces(input: &str) -> String {

while i < len {
if chars[i] == '\\' && i + 1 < len && chars[i + 1].is_ascii_alphabetic() {
let cmd_start = i + 1;
let mut cmd_end = cmd_start;
while cmd_end < len && chars[cmd_end].is_ascii_alphabetic() {
cmd_end += 1;
}
let cmd = &chars[cmd_start..cmd_end];
let is_left = cmd == ['l', 'e', 'f', 't'];

if is_left || cmd == ['r', 'i', 'g', 'h', 't'] {
if !is_left && result.ends_with(' ') {
result.pop();
}
i = cmd_end;
if i < len && chars[i] == '.' {
i += 1;
} else if is_left && i < len {
result.push(chars[i]);
i += 1;
if i < len && chars[i - 1] == '\\' && !chars[i].is_ascii_alphabetic() {
result.push(chars[i]);
i += 1;
}
if i < len && chars[i] == ' ' {
i += 1;
}
}
continue;
}

result.push('\\');
i += 1;
while i < len && chars[i].is_ascii_alphabetic() {
result.push(chars[i]);
i += 1;
for c in &chars[cmd_start..cmd_end] {
result.push(*c);
}
i = cmd_end;
if i < len && chars[i] == ' ' {
let next = chars.get(i + 1).copied().unwrap_or(' ');
if next.is_ascii_alphabetic() || next == '\\' || next == '{' {
let is_binop = cmd == ['c', 'd', 'o', 't']
|| cmd == ['t', 'i', 'm', 'e', 's']
|| cmd == ['d', 'i', 'v']
|| cmd == ['p', 'm']
|| cmd == ['m', 'p']
|| cmd == ['i', 'n']
|| cmd == ['c', 'a', 'p']
|| cmd == ['c', 'u', 'p'];
if !is_binop && (next.is_ascii_alphabetic() || next == '\\' || next == '{') {
i += 1;
}
}
Expand All @@ -38,6 +76,27 @@ fn postprocess(input: &str) -> String {
let mut i = 0;

while i < input.len() {
if input[i..].starts_with("\\text{") {
let brace_start = i + 6;
if let Some((content, end)) = read_brace_group(input, brace_start) {
result.push_str(content.trim());
i = end;
continue;
}
}

if input[i..].starts_with("\\begin{cases}") {
let after = i + 13;
if let Some(rel) = input[after..].find("\\end{cases}") {
let body = &input[after..after + rel];
let last_line = result.rsplit('\n').next().unwrap_or(&result);
let pad = UnicodeWidthStr::width(last_line);
result.push_str(&render_cases(body, pad));
i = after + rel + 11;
continue;
}
}

if input[i..].starts_with("\\frac{") {
if let Some((output, end)) = parse_frac(input, i) {
result.push_str(&output);
Expand All @@ -49,6 +108,14 @@ fn postprocess(input: &str) -> String {
continue;
}

if input[i..].starts_with("\\binom{") {
if let Some((output, end)) = parse_binom(input, i) {
result.push_str(&output);
i = end;
continue;
}
}

if input[i..].starts_with("√{") {
let brace_start = i + '√'.len_utf8() + 1;
if let Some((group, end)) = read_brace_group(input, brace_start) {
Expand Down Expand Up @@ -99,24 +166,36 @@ fn postprocess(input: &str) -> String {
result
}

fn parse_frac(input: &str, start: usize) -> Option<(String, usize)> {
let after_frac = start + 6;
let (num, after_num) = read_brace_group(input, after_frac)?;
if after_num >= input.len() || input.as_bytes()[after_num] != b'{' {
fn parse_two_groups(
input: &str,
start: usize,
prefix_len: usize,
) -> Option<(String, String, usize)> {
let after = start + prefix_len;
let (a, after_a) = read_brace_group(input, after)?;
if after_a >= input.len() || input.as_bytes()[after_a] != b'{' {
return None;
}
let (den, after_den) = read_brace_group(input, after_num + 1)?;
let num = postprocess(num);
let den = postprocess(den);
let (b, after_b) = read_brace_group(input, after_a + 1)?;
Some((postprocess(a), postprocess(b), after_b))
}

fn parse_frac(input: &str, start: usize) -> Option<(String, usize)> {
let (num, den, end) = parse_two_groups(input, start, 6)?;
let mut out = String::new();
wrap_if_multi(&mut out, &num);
out.push('/');
wrap_if_multi(&mut out, &den);
Some((out, after_den))
Some((out, end))
}

fn parse_binom(input: &str, start: usize) -> Option<(String, usize)> {
let (n, k, end) = parse_two_groups(input, start, 7)?;
Some((format!("C({n},{k})"), end))
}

fn wrap_if_multi(out: &mut String, s: &str) {
if s.chars().count() > 1 {
if s.chars().count() > 1 && s.contains(['+', '-', '−', '=', ' ', '<', '>', '/']) {
out.push('(');
out.push_str(s);
out.push(')');
Expand Down Expand Up @@ -165,6 +244,66 @@ fn read_brace_group(input: &str, start: usize) -> Option<(&str, usize)> {
}
}

fn render_cases(body: &str, prefix_width: usize) -> String {
let rows: Vec<&str> = body
.split("\\\\")
.map(|r| r.trim())
.filter(|r| !r.is_empty())
.collect();

if rows.is_empty() {
return "{ }".to_string();
}

let parsed: Vec<(String, Option<String>)> = rows
.iter()
.map(|row| {
let parts: Vec<&str> = row.splitn(2, '&').collect();
let value = postprocess(parts[0].trim());
let condition = parts.get(1).map(|p| postprocess(p.trim()));
(value, condition)
})
.collect();

let max_first_col = parsed
.iter()
.map(|(v, _)| UnicodeWidthStr::width(v.as_str()))
.max()
.unwrap_or(0);

let padding = " ".repeat(prefix_width);
let mut out = String::new();

for (idx, (value, condition)) in parsed.iter().enumerate() {
let brace = if parsed.len() == 1 {
"{"
} else if idx == 0 {
"\u{23A7}"
} else if idx == parsed.len() - 1 {
"\u{23A9}"
} else {
"\u{23AA}"
};

if idx > 0 {
out.push('\n');
out.push_str(&padding);
}
out.push_str(brace);
out.push(' ');
out.push_str(value);

if let Some(cond) = condition {
let val_width = UnicodeWidthStr::width(value.as_str());
let col_pad = max_first_col - val_width + 2;
out.push_str(&" ".repeat(col_pad));
out.push_str(cond);
}
}

out
}

fn to_superscript(ch: char) -> char {
match ch {
'0' => '⁰',
Expand Down
Loading