Add basic support for unions, basing the code largely on Struct. Add a test case.
[gnome.gobject-introspection] / giscanner / transformer.py
1 # -*- Mode: Python -*-
2 # GObject-Introspection - a framework for introspecting GObject libraries
3 # Copyright (C) 2008  Johan Dahlin
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # of the License, or (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
18 # 02110-1301, USA.
19 #
20
21 from giscanner.ast import (Callback, Enum, Function, Namespace, Member,
22                            Parameter, Return, Sequence, Struct, Field,
23                            Type, Alias, Interface, Class, Node, Union,
24                            type_name_from_ctype, type_names)
25 from .glibast import GLibBoxed
26 from giscanner.sourcescanner import (
27     SourceSymbol, ctype_name, CTYPE_POINTER,
28     CTYPE_BASIC_TYPE, CTYPE_UNION, CTYPE_ARRAY, CTYPE_TYPEDEF,
29     CTYPE_VOID, CTYPE_ENUM, CTYPE_FUNCTION, CTYPE_STRUCT,
30     CSYMBOL_TYPE_FUNCTION, CSYMBOL_TYPE_TYPEDEF, CSYMBOL_TYPE_STRUCT,
31     CSYMBOL_TYPE_ENUM, CSYMBOL_TYPE_UNION, CSYMBOL_TYPE_OBJECT,
32     CSYMBOL_TYPE_MEMBER)
33 from .odict import odict
34 from .utils import strip_common_prefix
35
36
class SkipError(Exception):
    """Signal that a symbol uses a C type the transformer cannot represent.

    Raised by Transformer._create_type (currently for ``va_list`` and
    ``FILE*``); the caller is expected to drop the offending symbol.
    """
39
40
class Names(object):
    """Lookup tables used when resolving type references.

    Every table is exposed through a read-only property; the dicts
    themselves are mutable and are filled in by the Transformer.
    """

    def __init__(self):
        super(Names, self).__init__()
        # GIName -> (namespace, node); insertion-ordered
        self._names = odict()
        # Alias GIName -> (namespace, Alias node)
        self._aliases = {}
        # GType name -> (namespace, node)
        self._type_names = {}
        # C type / C symbol -> (namespace, node)
        self._ctypes = {}

    @property
    def names(self):
        return self._names

    @property
    def aliases(self):
        return self._aliases

    @property
    def type_names(self):
        return self._type_names

    @property
    def ctypes(self):
        return self._ctypes
53
54
class Transformer(object):
    """Convert source-scanner symbols into giscanner.ast nodes.

    ``parse()`` walks ``generator.get_symbols()`` and collects the nodes it
    builds into a single Namespace.  Names registered from included
    .gir/.gidl files (``register_include``) are used to resolve type
    references into other namespaces.
    """

    def __init__(self, generator, namespace_name):
        self.generator = generator
        self._namespace = Namespace(namespace_name)
        self._names = Names()
        # C identifier of a typedef'd struct/union -> the Struct/Union node
        # created for it, so the struct/union definition with the same
        # identifier fills in that node's fields instead of a new one.
        self._typedefs_ns = {}
        self._strip_prefix = ''

    def get_names(self):
        """Return the Names tables used for type resolution."""
        return self._names

    def set_strip_prefix(self, strip_prefix):
        """Set the C prefix (e.g. 'g') removed by _remove_prefix."""
        self._strip_prefix = strip_prefix

    def parse(self):
        """Transform every scanned symbol and return the populated Namespace."""
        for symbol in self.generator.get_symbols():
            node = self._traverse_one(symbol)
            self._add_node(node)
        return self._namespace

    def register_include(self, filename):
        """Register the nodes of an included .gir or .gidl file.

        Each node is recorded under its GI name, under its GType name
        (boxed, interface and class nodes) or alias name, and under its
        ``ctype`` or ``symbol`` attribute when present.

        Raises NotImplementedError for any other file extension.
        """
        if filename.endswith('.gir'):
            from .girparser import GIRParser
            parser = GIRParser(filename)
        elif filename.endswith('.gidl'):
            from .gidlparser import GIDLParser
            parser = GIDLParser(filename)
        else:
            raise NotImplementedError(filename)
        nsname = parser.get_namespace_name()
        for node in parser.get_nodes():
            if isinstance(node, Alias):
                self._names.aliases[node.name] = (nsname, node)
            elif isinstance(node, (GLibBoxed, Interface, Class)):
                self._names.type_names[node.type_name] = (nsname, node)
            self._names.names[node.name] = (nsname, node)
            if hasattr(node, 'ctype'):
                self._names.ctypes[node.ctype] = (nsname, node)
            elif hasattr(node, 'symbol'):
                self._names.ctypes[node.symbol] = (nsname, node)

    def strip_namespace_object(self, name):
        """Strip the namespace name from a type name (GLibMainLoop -> MainLoop).

        The match is case-insensitive and only applied when something would
        remain; otherwise the name goes through _remove_prefix instead.
        """
        prefix = self._namespace.name.lower()
        if len(name) > len(prefix) and name.lower().startswith(prefix):
            return name[len(prefix):]
        return self._remove_prefix(name)

    # Private

    def _add_node(self, node):
        """Append node to the namespace; None and '_'-prefixed names are ignored."""
        if node is None:
            return
        if node.name.startswith('_'):
            return
        self._namespace.nodes.append(node)
        # Local nodes are registered without a namespace qualifier.
        self._names.names[node.name] = (None, node)

    def _strip_namespace_func(self, name):
        """Strip '<namespace>_' (case-insensitive) and the C prefix from a function name."""
        prefix = self._namespace.name.lower() + '_'
        if name.lower().startswith(prefix):
            name = name[len(prefix):]
        return self._remove_prefix(name)

    def _remove_prefix(self, name):
        """Remove the configured strip prefix and any leading underscores."""
        # when --strip-prefix=g:
        #   GHashTable -> HashTable
        #   g_hash_table_new -> hash_table_new
        if name.lower().startswith(self._strip_prefix.lower()):
            name = name[len(self._strip_prefix):]

        while name.startswith('_'):
            name = name[1:]
        return name

    def _traverse_one(self, symbol, stype=None):
        """Dispatch a symbol to the matching _create_* method.

        ``stype`` overrides ``symbol.type`` when given.  Returns the created
        node, or None when the symbol is skipped.  Raises
        NotImplementedError for unknown symbol types.
        """
        assert isinstance(symbol, SourceSymbol), symbol

        if stype is None:
            stype = symbol.type
        if stype == CSYMBOL_TYPE_FUNCTION:
            create = self._create_function
        elif stype == CSYMBOL_TYPE_TYPEDEF:
            create = self._create_typedef
        elif stype == CSYMBOL_TYPE_STRUCT:
            create = self._create_struct
        elif stype == CSYMBOL_TYPE_ENUM:
            create = self._create_enum
        elif stype == CSYMBOL_TYPE_OBJECT:
            create = self._create_object
        elif stype == CSYMBOL_TYPE_MEMBER:
            create = self._create_member
        elif stype == CSYMBOL_TYPE_UNION:
            create = self._create_union
        else:
            raise NotImplementedError(
                'Transformer: unhandled symbol: %r' % (symbol, ))
        try:
            return create(symbol)
        except SkipError:
            # _create_type raises SkipError for types it cannot represent
            # (va_list, FILE*).  Those can appear in functions, struct/union
            # members and callback typedefs alike; drop just this symbol
            # instead of aborting the whole parse.
            return None

    def _create_enum(self, symbol):
        """Build an Enum node from an enum symbol (or enum typedef)."""
        members = []
        for child in symbol.base_type.child_list:
            # presumably drops the prefix the member shares with the enum
            # ident (strip_common_prefix lives in .utils) -- verify there
            name = strip_common_prefix(symbol.ident, child.ident).lower()
            members.append(Member(name, child.const_int, child.ident))

        enum_name = self.strip_namespace_object(symbol.ident)
        # enum_name is always a suffix of the ident, so this re-slice only
        # matters when it came back empty: ident[-0:] is the full ident.
        enum_name = symbol.ident[-len(enum_name):]
        enum_name = self._remove_prefix(enum_name)
        enum = Enum(enum_name, symbol.ident, members)
        self._names.type_names[symbol.ident] = (None, enum)
        return enum

    def _create_object(self, symbol):
        """Build a Member node for an object-like symbol."""
        return Member(symbol.ident, symbol.base_type.name,
                      symbol.ident)

    def _create_function(self, symbol):
        """Build a Function node, honouring per-parameter/return directives.

        Raises SkipError if any parameter or the return type is unsupported.
        """
        directives = symbol.directives()
        parameters = list(self._create_parameters(
            symbol.base_type, directives))
        return_ = self._create_return(symbol.base_type.base_type,
                                      directives.get('return', []))
        name = self._remove_prefix(symbol.ident)
        name = self._strip_namespace_func(name)
        return Function(name, return_, parameters, symbol.ident)

    def _create_source_type(self, source_type):
        """Render a scanner type as a C type string ('int', 'GList*', ...).

        Arrays collapse to their element type; unsupported kinds are
        reported and rendered as 'any'.
        """
        if source_type is None:
            return 'None'
        if source_type.type == CTYPE_VOID:
            value = 'void'
        elif source_type.type == CTYPE_BASIC_TYPE:
            value = source_type.name
        elif source_type.type == CTYPE_TYPEDEF:
            value = source_type.name
        elif source_type.type == CTYPE_ARRAY:
            return self._create_source_type(source_type.base_type)
        elif source_type.type == CTYPE_POINTER:
            value = self._create_source_type(source_type.base_type) + '*'
        else:
            print('TRANSFORMER: Unhandled source type %r' % (
                source_type, ))
            value = 'any'
        return value

    def _create_parameters(self, base_type, options=None):
        """Lazily yield a Parameter for each child of a function type."""
        if not options:
            options = {}
        for child in base_type.child_list:
            yield self._create_parameter(
                child, options.get(child.ident, []))

    def _create_member(self, symbol):
        """Build a struct/union member: a Callback for function pointers,
        otherwise a Field."""
        ctype = symbol.base_type.type
        if (ctype == CTYPE_POINTER and
            symbol.base_type.base_type.type == CTYPE_FUNCTION):
            node = self._create_callback(symbol)
        else:
            ftype = self._create_type(symbol.base_type)
            node = Field(symbol.ident, ftype, symbol.ident)
        return node

    def _create_typedef(self, symbol):
        """Build the node for a typedef.

        Function pointers become Callbacks, struct/union typedefs become
        empty Struct/Union nodes, enums become Enums, and typedefs of named
        basic/pointer/typedef types become Aliases (unnamed ones give None).
        Raises NotImplementedError for any other target type.
        """
        ctype = symbol.base_type.type
        if (ctype == CTYPE_POINTER and
            symbol.base_type.base_type.type == CTYPE_FUNCTION):
            node = self._create_callback(symbol)
        elif ctype == CTYPE_STRUCT:
            node = self._create_typedef_struct(symbol)
        elif ctype == CTYPE_UNION:
            node = self._create_typedef_union(symbol)
        elif ctype == CTYPE_ENUM:
            return self._create_enum(symbol)
        elif ctype in (CTYPE_TYPEDEF,
                       CTYPE_POINTER,
                       CTYPE_BASIC_TYPE,
                       CTYPE_VOID):
            if symbol.base_type.name:
                name = self.strip_namespace_object(symbol.ident)
                target = self.strip_namespace_object(symbol.base_type.name)
                return Alias(name, target, ctype=symbol.ident)
            return None
        else:
            raise NotImplementedError(
                "symbol %r of type %s" % (symbol.ident, ctype_name(ctype)))
        return node

    def _create_type(self, source_type):
        """Build a Type node with a resolved GI type name.

        Raises SkipError for types that are not supported yet.
        """
        ctype = self._create_source_type(source_type)
        if ctype == 'va_list':
            raise SkipError
        # FIXME: FILE* should not be skipped, it should be handled
        #        properly instead
        elif ctype == 'FILE*':
            raise SkipError
        type_name = type_name_from_ctype(ctype)
        type_name = type_name.replace('*', '')
        resolved_type_name = self.resolve_type_name(type_name)
        return Type(resolved_type_name, ctype)

    def _create_parameter(self, symbol, options):
        """Build a Parameter, applying direction/ownership annotations."""
        ptype = self._create_type(symbol.base_type)
        param = Parameter(symbol.ident, ptype)
        for option in options:
            if option in ['in-out', 'inout']:
                param.direction = 'inout'
            elif option == 'in':
                param.direction = 'in'
            elif option == 'out':
                param.direction = 'out'
            elif option == 'callee-owns':
                param.transfer = True
            elif option == 'allow-none':
                param.allow_none = True
            else:
                print('Unhandled parameter annotation option: %s' % (
                    option, ))
        return param

    def _create_return(self, source_type, options=None):
        """Build a Return node, applying 'caller-owns' and 'seq' annotations.

        ``seq <element-ctype> [extra words]`` turns the return type into a
        Sequence of the given element type; extra words are ignored.
        """
        if not options:
            options = []
        rtype = self._create_type(source_type)
        rtype = self.resolve_param_type(rtype)
        return_ = Return(rtype)
        for option in options:
            if option == 'caller-owns':
                return_.transfer = True
            elif option.startswith('seq ') and option[4:].strip():
                # maxsplit=1 and taking only the first word: the old
                # two-name unpack of split(None, 2) crashed on "seq Foo*"
                # (one word) and on three or more words.
                value = option[4:].split(None, 1)[0]
                # NOTE(review): a bracketed value ("[...]") makes
                # _parse_type_annotation return a Sequence, which has no
                # .replace(); confirm such annotations never reach here.
                c_element_type = self._parse_type_annotation(value)
                element_type = c_element_type.replace('*', '')
                element_type = self.resolve_type_name(element_type,
                                                      c_element_type)
                seq = Sequence(rtype.name,
                               type_name_from_ctype(rtype.name),
                               element_type)
                seq.transfer = True
                return_.type = seq
            else:
                print('Unhandled parameter annotation option: %s' % (
                    option, ))
        return return_

    def _create_typedef_compound(self, symbol, klass):
        """Create an empty ``klass`` (Struct/Union) node for a typedef.

        The node is remembered in _typedefs_ns so the matching definition
        can add its fields to it.
        """
        name = self._remove_prefix(symbol.ident)
        name = self.strip_namespace_object(name)
        node = klass(name, symbol.ident)
        self._typedefs_ns[symbol.ident] = node
        return node

    def _create_typedef_struct(self, symbol):
        """Create the Struct node for a struct typedef."""
        return self._create_typedef_compound(symbol, Struct)

    def _create_typedef_union(self, symbol):
        """Create the Union node for a union typedef."""
        return self._create_typedef_compound(symbol, Union)

    def _create_compound(self, symbol, klass):
        """Build (or complete) a Struct/Union node from its definition.

        If a typedef with the same identifier already produced a node, that
        node receives the fields; otherwise a new ``klass`` node is created.
        Members that cannot be transformed are left out.
        """
        node = self._typedefs_ns.get(symbol.ident, None)
        if node is None:
            # This is a bit of a hack; really we should try
            # to resolve through the typedefs to find the real
            # name
            if symbol.ident.startswith('_'):
                name = symbol.ident[1:]
            else:
                name = symbol.ident
            name = self._remove_prefix(name)
            name = self.strip_namespace_object(name)
            name = self.resolve_type_name(name)
            node = klass(name, symbol.ident)

        for child in symbol.base_type.child_list:
            field = self._traverse_one(child)
            if field:
                node.fields.append(field)

        return node

    def _create_struct(self, symbol):
        """Build (or complete) the Struct node for a struct definition."""
        return self._create_compound(symbol, Struct)

    def _create_union(self, symbol):
        """Build (or complete) the Union node for a union definition."""
        return self._create_compound(symbol, Union)

    def _create_callback(self, symbol):
        """Build a Callback node from a function-pointer symbol."""
        # The parameter generator is consumed only after the return type
        # has been built, matching the original evaluation order.
        parameters = self._create_parameters(symbol.base_type.base_type)
        retval = self._create_return(symbol.base_type.base_type.base_type)
        name = self.strip_namespace_object(symbol.ident)
        return Callback(name, retval, list(parameters), symbol.ident)

    def _parse_type_annotation(self, annotation):
        """Parse a type annotation; '[T]' becomes a Sequence of T.

        Anything else (including an empty string) is returned unchanged.
        """
        if annotation.startswith('[') and annotation.endswith(']'):
            return Sequence(self._parse_type_annotation(annotation[1:-1]))
        return annotation

    def _typepair_to_str(self, item):
        """Format a (namespace, node) pair as 'name' or 'namespace.name'."""
        nsname, item = item
        if nsname is None:
            return item.name
        return '%s.%s' % (nsname, item.name)

    def _resolve_type_name_1(self, type_name, ctype, names):
        """Resolve a type name against the built-in names and one Names table.

        Tries, in order: built-in names by ctype, then by type_name, then
        aliases, GI names and C types in ``names``.  Raises KeyError if
        nothing matches.
        """
        # First look using the built-in names
        if ctype:
            try:
                return type_names[ctype]
            except KeyError:
                pass
        try:
            return type_names[type_name]
        except KeyError:
            pass
        type_name = self.strip_namespace_object(type_name)
        resolved = names.aliases.get(type_name)
        if resolved:
            return self._typepair_to_str(resolved)
        resolved = names.names.get(type_name)
        if resolved:
            return self._typepair_to_str(resolved)
        if ctype:
            ctype = ctype.replace('*', '')
            resolved = names.ctypes.get(ctype)
            if resolved:
                return self._typepair_to_str(resolved)
        raise KeyError("failed to find %r" % (type_name, ))

    def resolve_type_name_full(self, type_name, ctype, names):
        """Resolve against ``names``, then this transformer's own tables.

        Returns type_name unchanged when neither lookup succeeds.
        """
        for lookup_names in (names, self._names):
            try:
                return self._resolve_type_name_1(type_name, ctype,
                                                 lookup_names)
            except KeyError:
                continue
        return type_name

    def resolve_type_name(self, type_name, ctype=None):
        """Resolve type_name using this transformer's own Names tables."""
        try:
            return self.resolve_type_name_full(type_name, ctype, self._names)
        except KeyError:
            return type_name

    def gtypename_to_giname(self, gtname, names):
        """Map a GType name to a (possibly namespace-qualified) GI name.

        Raises KeyError when neither ``names`` nor the own tables know it.
        """
        resolved = names.type_names.get(gtname)
        if resolved:
            return self._typepair_to_str(resolved)
        resolved = self._names.type_names.get(gtname)
        if resolved:
            return self._typepair_to_str(resolved)
        raise KeyError("Failed to resolve GType name: %r" % (gtname, ))

    def ctype_of(self, obj):
        """Return obj.ctype, else obj.symbol, else None."""
        if hasattr(obj, 'ctype'):
            return obj.ctype
        elif hasattr(obj, 'symbol'):
            return obj.symbol
        else:
            return None

    def resolve_param_type_full(self, ptype, names):
        """Resolve the type name(s) inside ptype.

        Sequence and Node values are updated in place and returned; a plain
        string is resolved and the new string returned.  Raises
        AssertionError for anything else.
        """
        if isinstance(ptype, Sequence):
            ptype.element_type = \
                self.resolve_param_type_full(ptype.element_type, names)
        elif isinstance(ptype, Node):
            ptype.name = self.resolve_type_name_full(ptype.name,
                                                     self.ctype_of(ptype),
                                                     names)
        elif isinstance(ptype, basestring):
            return self.resolve_type_name_full(ptype, None, names)
        else:
            raise AssertionError("Unhandled param: %r" % (ptype, ))
        return ptype

    def resolve_param_type(self, ptype):
        """resolve_param_type_full against this transformer's own tables."""
        try:
            return self.resolve_param_type_full(ptype, self._names)
        except KeyError:
            return ptype