diff --git a/babi/file.py b/babi/file.py index 14ae464b..a71056c2 100644 --- a/babi/file.py +++ b/babi/file.py @@ -221,6 +221,21 @@ def __init__( self.selection = Selection() self._file_hls: tuple[FileHL, ...] = () + def refresh_syntax(self) -> None: + file_hls = [] + for factory in self._hl_factories: + if self.filename is not None: + hl = factory.file_highlighter(self.filename, self.buf[0]) + file_hls.append(hl) + else: + file_hls.append(factory.blank_file_highlighter()) + self._file_hls = ( + *file_hls, + self._trailing_whitespace, self._replace_hl, self.selection, + ) + for file_hl in self._file_hls: + file_hl.register_callbacks(self.buf) + def ensure_loaded( self, status: Status, @@ -255,20 +270,7 @@ def ensure_loaded( status.update(f'mixed newlines will be converted to {self.nl!r}') self.modified = True - file_hls = [] - for factory in self._hl_factories: - if self.filename is not None: - hl = factory.file_highlighter(self.filename, self.buf[0]) - file_hls.append(hl) - else: - file_hls.append(factory.blank_file_highlighter()) - self._file_hls = ( - *file_hls, - self._trailing_whitespace, self._replace_hl, self.selection, - ) - for file_hl in self._file_hls: - file_hl.register_callbacks(self.buf) - + self.refresh_syntax() self.go_to_line(self.initial_line, margin) def __repr__(self) -> str: diff --git a/babi/screen.py b/babi/screen.py index c629c46d..9c2b7184 100644 --- a/babi/screen.py +++ b/babi/screen.py @@ -565,6 +565,16 @@ def save(self) -> PromptResult | None: else: self.file.filename = filename + if os.path.exists(self.file.filename): + with open(self.file.filename) as f: + x = f.readline()[:-1] + refresh_syntax = ( + x != self.file.buf[0] and + self.file.buf[0].startswith('#!/') or x.startswith('#!/') + ) + else: + refresh_syntax = True + if not os.path.isfile(self.file.filename): sha256: str | None = None else: @@ -596,6 +606,10 @@ def save(self) -> PromptResult | None: lines = 'lines' if num_lines != 1 else 'line' self.status.update(f'saved! ({num_lines} {lines} written)') self.file.reset_modified_state() + + if refresh_syntax: + self.file.refresh_syntax() + return None def save_filename(self) -> PromptResult | None: