Finding all non-inheriting C# classes with Roslyn and making them inherit from a base object (Java-like)


I'm working on a small Roslyn project that modifies a parse tree and writes the changes back to the file. I started from the standalone code analyzer and want to turn it into a command-line app, but I've run into a problem. Working partly from "Find classes which derive from a specific base class with Roslyn" and mostly from https://github.com/dotnet/roslyn/wiki/Getting-Started-C%23-Syntax-Analysis, I've created this small project:

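using System;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

class Program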
{
    static void Main(string[] args)
    {
        try
        {
            if (args.Length < 1)
                throw new ArgumentException();
            SyntaxTree tree = CSharpSyntaxTree.ParseText(File.ReadAllText(args[0]));
            var root = (CompilationUnitSyntax) tree.GetRoot();
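            // OfType filters to class declarations; the original query
            // "from ClassDeclarationSyntax in ..." only named the range variable
            // ClassDeclarationSyntax and therefore selected every node.
            var classes = root.DescendantNodesAndSelf().OfType<ClassDeclarationSyntax>();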
            foreach (var c in classes)
            {
                if (/*Not inheriting*/)
                {
                    /*Add inheritance*/
                }
            }
            /*Write changes to file*/
        }
        catch (Exception e)
        {
            Console.WriteLine("Fatal error occurred.");
            Console.WriteLine(e.Message);
            Console.WriteLine(e.StackTrace);
        }
    }
}

As the comments in the code say, I need to check whether each class inherits from something (I'm interested in the ones that don't), then change the parse tree, and finally write it back to the file. For now I'd be glad just to know how to detect "no inheritance", although any pointers for steps two and three are welcome too. The file to parse and change is passed as a path in the program's first argument:

if (args.Length < 1)
    throw new ArgumentException();
SyntaxTree tree = CSharpSyntaxTree.ParseText(File.ReadAllText(args[0]));

**SOLUTION**
With help from the answers I received, I've come up with a working app. Here's my code; it may not be perfect, but it works:
using System;
using System.IO;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

class Program
{
    static void Main(string[] args)
    {
        try
        {
            if (args.Length < 1)
                throw new ArgumentException("Usage: <source file> [output file]");
            SyntaxTree tree = CSharpSyntaxTree.ParseText(File.ReadAllText(args[0]));
            var root = (CompilationUnitSyntax) tree.GetRoot();

            // Build the base list ": Object" once and reuse it for every class.
            IdentifierNameSyntax iname = SyntaxFactory.IdentifierName("Object");
            BaseTypeSyntax bts = SyntaxFactory.SimpleBaseType(iname);
            SeparatedSyntaxList<BaseTypeSyntax> ssl = new SeparatedSyntaxList<BaseTypeSyntax>();
            ssl = ssl.Add(bts);
            BaseListSyntax bls = SyntaxFactory.BaseList(ssl);

            // ReplaceNode returns a new root, so the nodes being enumerated belong to
            // the old tree. Restart the scan after every replacement and stop once a
            // full pass makes no change.
            bool changed = true;
            while (changed)
            {
                changed = false;
                foreach (var c in root.DescendantNodesAndSelf())
                {
                    var classDeclaration = c as ClassDeclarationSyntax;
                    if (classDeclaration == null)
                        continue;
                    if (classDeclaration.BaseList != null) // Inherits or implements interfaces
                        continue;
                    // Does not inherit: add ": Object"
                    root = root.ReplaceNode(classDeclaration, classDeclaration.WithBaseList(bls));
                    changed = true;
                    break;
                }
            }

            // Write to the given output file, or overwrite the source.
            // File.WriteAllText creates or truncates the file, so no stale bytes remain
            // (File.Open with FileMode.Open fails on a missing file and never truncates).
            string outputPath = args.Length > 1 ? args[1] : args[0];
            File.WriteAllText(outputPath, root.ToFullString());
        }
        catch (Exception e)
        {
            Console.WriteLine("Fatal error occurred.");
            Console.WriteLine(e.Message);
            Console.WriteLine(e.StackTrace);
        }
    }
}
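
The solution's while loop restarts the scan after every replacement, because ReplaceNode returns a new root and the enumeration in progress still belongs to the old tree. One possible alternative, sketched here and not taken from the answer, is Roslyn's ReplaceNodes extension method, which collects all targets from the original tree and rewrites them in a single pass. SinglePassFixer and AddObjectBase are hypothetical names used only for illustration.

```csharp
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

// Hypothetical helper: adds ": Object" to every class that has no base list,
// in one ReplaceNodes pass instead of restarting the scan for each class.
public static class SinglePassFixer
{
    public static SyntaxNode AddObjectBase(SyntaxNode root)
    {
        // ": Object" built from SimpleBaseType and BaseList, as in the solution.
        BaseListSyntax objectBase = SyntaxFactory.BaseList(
            SyntaxFactory.SingletonSeparatedList<BaseTypeSyntax>(
                SyntaxFactory.SimpleBaseType(SyntaxFactory.IdentifierName("Object"))));

        // Select the targets from the original tree.
        var targets = root.DescendantNodesAndSelf()
                          .OfType<ClassDeclarationSyntax>()
                          .Where(c => c.BaseList == null);

        // ReplaceNodes passes each original node plus a version whose descendants
        // (e.g. nested classes) were already rewritten, so build on `rewritten`.
        return root.ReplaceNodes(targets, (original, rewritten) => rewritten.WithBaseList(objectBase));
    }
}
```

As in the solution, the inserted ": Object" carries no whitespace trivia; calling NormalizeWhitespace() on the result is one way to tidy the layout, at the cost of reformatting the whole file.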

There is one answer.

Best answer, by George Alexandria:

ClassDeclarationSyntax has a BaseList property whose Types collection lists the declared base types. You can use these properties to find out whether a class declares any base types:

        foreach (var c in root.DescendantNodesAndSelf())
        {
            var classDeclaration = c as ClassDeclarationSyntax;
            if (classDeclaration == null)
            {
                continue;
            }
            if (classDeclaration.BaseList?.Types.Count > 0)
            {
                Console.WriteLine("This class has a base class or implements interfaces");
            }
            else
            {
                /*Add inheritance*/
            }
        }

Unfortunately, you need extra logic to tell whether your class has a base class or only implements interfaces, because both appear in the same BaseList. To distinguish them, you need to analyze each base type (class or interface) with the semantic model to get the corresponding ISymbol, or look for the declarations of these types in the syntax tree, if they are declared in your own projects/solutions.
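
As one possible sketch of the semantic-model route (not the answerer's code), the helper below asks the SemanticModel for the symbol behind each entry in BaseList.Types and reports whether any of them is a class. BaseClassInspector, HasExplicitBaseClass and CreateModel are hypothetical names. CreateModel assumes a single-file compilation that references only the core library through typeof(object).Assembly.Location; a real project would add whatever references its sources need.

```csharp
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

public static class BaseClassInspector
{
    // Hypothetical helper: true if the base list names a class,
    // false if it is absent or contains only interfaces (or unresolved types).
    public static bool HasExplicitBaseClass(ClassDeclarationSyntax cls, SemanticModel model)
    {
        if (cls.BaseList == null)
            return false;
        return cls.BaseList.Types.Any(t =>
            model.GetSymbolInfo(t.Type).Symbol is INamedTypeSymbol named
            && named.TypeKind == TypeKind.Class);
    }

    // Assumption: one syntax tree plus the core library is enough to bind it.
    public static SemanticModel CreateModel(SyntaxTree tree)
    {
        var compilation = CSharpCompilation.Create(
            "Analysis",
            new[] { tree },
            new[] { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) });
        return compilation.GetSemanticModel(tree);
    }
}
```

Base types that cannot be resolved (for example, a class from an assembly that is not referenced) are not reported as classes here, so it is worth checking the compilation's diagnostics when references may be missing.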

Also, if you want to add inheritance to a class, you need to set its BaseList to a newly created node built with SyntaxFactory.SimpleBaseType(...) and SyntaxFactory.BaseList(...).
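
A small sketch of that step, assuming the class has no type parameters and no existing base list. The hypothetical BaseListBuilder.AddBaseClass builds ": BaseName" with SyntaxFactory.SimpleBaseType and SyntaxFactory.BaseList, applies it with WithBaseList, and moves the identifier's trailing whitespace (a space or the newline before "{") to the end of the new base list so the original layout is kept.

```csharp
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

public static class BaseListBuilder
{
    // Hypothetical helper: returns a copy of `cls` inheriting from `baseTypeName`.
    // Assumes `cls` has no type parameter list and no existing base list.
    public static ClassDeclarationSyntax AddBaseClass(ClassDeclarationSyntax cls, string baseTypeName)
    {
        // ": BaseTypeName" from SimpleBaseType + BaseList.
        BaseListSyntax baseList = SyntaxFactory.BaseList(
            SyntaxFactory.SingletonSeparatedList<BaseTypeSyntax>(
                SyntaxFactory.SimpleBaseType(SyntaxFactory.ParseTypeName(baseTypeName))));

        // Give the colon a space on each side, and move the identifier's
        // trailing trivia to the end of the base list.
        SyntaxToken identifier = cls.Identifier;
        baseList = baseList
            .WithColonToken(SyntaxFactory.Token(SyntaxKind.ColonToken)
                .WithLeadingTrivia(SyntaxFactory.Space)
                .WithTrailingTrivia(SyntaxFactory.Space))
            .WithTrailingTrivia(identifier.TrailingTrivia);

        return cls
            .WithIdentifier(identifier.WithTrailingTrivia(SyntaxFactory.TriviaList()))
            .WithBaseList(baseList);
    }
}
```

Passing "object" (parsed as the predefined type keyword) instead of "Object" avoids depending on a `using System;` directive in the file being rewritten.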