How to modify .py file using LibCST?


Hello.

I have some Python source code that I am trying to modify using LibCST. In short, I need to import 3 modules into the source .py file at a specific location. I am using LibCST instead of AST because I need to preserve comments and whitespace. Below is an example of what I'm trying to achieve.

I've read the documentation and tried to utilize their examples but I cannot wrap my head around it. Any help at all is greatly appreciated!

Source File

import car
import horse
import cheese

x = 1

Modified Source File

import car
import horse
import cheese

import new
import packages

x = 1

Answer by zaicruvoir1rominet

How do I use LibCST?

1. Check the "end result" CST

Parse the code you want to end up with, then display its CST with libcst.tool.dump, which prints a condensed, readable view of the tree.

import libcst.tool

code_sample = r'''
import car
import horse
import cheese

import new
import packages

x = 1
'''

cst = libcst.parse_module(code_sample)

# Display only the relevant parts of the CST
print(libcst.tool.dump(cst))

# To print the complete CST, whitespace nodes included (very verbose)
print(cst)

Output:

Module(
  body=[

    # ... the other import statements ...
    SimpleStatementLine(
      body=[
        Import(
          names=[
            ImportAlias(
              name=Name(
                value='packages',
              ),
            ),
          ],
        ),
      ],
    ),

    # and x = 1
    SimpleStatementLine(
      body=[
        Assign(
          targets=[
            AssignTarget(
              target=Name(
                value='x',
              ),
            ),
          ],
          value=Integer(
            value='1',
          ),
        ),
      ],
    ),
  ],
)

It's just like reading code. To avoid getting lost, I focus only on the important parts:
it's a Python file (Module),
which contains a few lines (SimpleStatementLine),
some of which contain imports (Import), each with an ImportAlias whose name is a Name(value=...),
and the last line contains an assignment (Assign) and other details we don't care about.

2. Write a code transformer

Documentation: How to transform code using libcst - tutorial

The easiest way to add code is to recognize a code pattern at the place where code needs to be inserted. For instance, if code needs to be inserted after the import cheese statement, the transformer inserts code when it visits a SimpleStatementLine containing an Import whose ImportAlias has the Name "cheese".

Python 3.10's structural pattern matching (match/case) makes this check easy to write:

import libcst
from collections.abc import Sequence


class AddCode(libcst.CSTTransformer):
    """Transformer that adds code after `import cheese` (ugly, just for the example)"""

    def __init__(self, add_code: libcst.BaseStatement | Sequence[libcst.BaseStatement]) -> None:
        super().__init__()
        # Accept either a single statement or a sequence of statements
        if isinstance(add_code, libcst.CSTNode):
            add_code = [add_code]
        self._code_to_add = list(add_code)

    def leave_SimpleStatementLine(
        self, original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine
    ) -> libcst.BaseStatement | libcst.FlattenSentinel[libcst.BaseStatement] | libcst.RemovalSentinel:
        # Called when leaving each simple statement line
        match updated_node:
            # note: this is just a plain copy & paste of the pattern as seen in the CST dump
            case libcst.SimpleStatementLine(
              body=[
                libcst.Import(
                  names=[
                    libcst.ImportAlias(
                      name=libcst.Name(
                        value='cheese',
                      ),
                    ),
                  ],
                ),
              ],
            ):
                # The line is exactly `import cheese`: keep it and append the new statements
                return libcst.FlattenSentinel([updated_node, *self._code_to_add])

        # Otherwise leave the line unchanged
        return updated_node

3. And finally

import libcst

# AddCode is the transformer class defined in step 2
code_sample = r'''
import car
import horse
import cheese

x = 1
'''

code_to_add = r'''
import new
import packages
'''

cst = libcst.parse_module(code_sample)
to_add = libcst.parse_module(code_to_add).body

# Apply transformer to CST
updated_cst = cst.visit(AddCode(to_add))

print(updated_cst.code)

Output:

import car
import horse
import cheese
import new
import packages

x = 1